1b4d4843332effc7683bcbd20a7f9a1432cc299e
[python_utils.git] / bootstrap.py
1 #!/usr/bin/env python3
2
3 import functools
4 import logging
5 import os
6 from inspect import stack
7 import sys
8
9 # This module is commonly used by others in here and should avoid
10 # taking any unnecessary dependencies back on them.
11
12 from argparse_utils import ActionNoYes
13 import config
14 import logging_utils
15
16 logger = logging.getLogger(__name__)
17
18 args = config.add_commandline_args(
19     f'Bootstrap ({__file__})',
20     'Args related to python program bootstrapper and Swiss army knife',
21 )
22 args.add_argument(
23     '--debug_unhandled_exceptions',
24     action=ActionNoYes,
25     default=False,
26     help='Break into pdb on top level unhandled exceptions.',
27 )
28 args.add_argument(
29     '--show_random_seed',
30     action=ActionNoYes,
31     default=False,
32     help='Should we display (and log.debug) the global random seed?',
33 )
34 args.add_argument(
35     '--set_random_seed',
36     type=int,
37     nargs=1,
38     default=None,
39     metavar='SEED_INT',
40     help='Override the global random seed with a particular number.',
41 )
42 args.add_argument(
43     '--dump_all_objects',
44     action=ActionNoYes,
45     default=False,
46     help='Should we dump the Python import tree before main?',
47 )
48 args.add_argument(
49     '--audit_import_events',
50     action=ActionNoYes,
51     default=False,
52     help='Should we audit all import events?',
53 )
54 args.add_argument(
55     '--run_profiler',
56     action=ActionNoYes,
57     default=False,
58     help='Should we run cProfile on this code?',
59 )
60
61 original_hook = sys.excepthook
62
63
64 def handle_uncaught_exception(exc_type, exc_value, exc_tb):
65     """
66     Top-level exception handler for exceptions that make it past any exception
67     handlers in the python code being run.  Logs the error and stacktrace then
68     maybe attaches a debugger.
69
70     """
71     global original_hook
72     msg = f'Unhandled top level exception {exc_type}'
73     logger.exception(msg)
74     print(msg, file=sys.stderr)
75     if issubclass(exc_type, KeyboardInterrupt):
76         sys.__excepthook__(exc_type, exc_value, exc_tb)
77         return
78     else:
79         if not sys.stderr.isatty() or not sys.stdin.isatty():
80             # stdin or stderr is redirected, just do the normal thing
81             original_hook(exc_type, exc_value, exc_tb)
82         else:
83             # a terminal is attached and stderr is not redirected, maybe debug.
84             import traceback
85
86             traceback.print_exception(exc_type, exc_value, exc_tb)
87             if config.config['debug_unhandled_exceptions']:
88                 import pdb
89
90                 logger.info("Invoking the debugger...")
91                 pdb.pm()
92             else:
93                 original_hook(exc_type, exc_value, exc_tb)
94
95
96 class ImportInterceptor(object):
97     def __init__(self):
98         import collect.trie
99
100         self.module_by_filename_cache = {}
101         self.repopulate_modules_by_filename()
102         self.tree = collect.trie.Trie()
103         self.tree_node_by_module = {}
104
105     def repopulate_modules_by_filename(self):
106         self.module_by_filename_cache.clear()
107         for mod in sys.modules:
108             if hasattr(sys.modules[mod], '__file__'):
109                 fname = getattr(sys.modules[mod], '__file__')
110             else:
111                 fname = 'unknown'
112             self.module_by_filename_cache[fname] = mod
113
114     def should_ignore_filename(self, filename: str) -> bool:
115         return 'importlib' in filename or 'six.py' in filename
116
117     def find_spec(self, loaded_module, path=None, target=None):
118         s = stack()
119         for x in range(3, len(s)):
120             filename = s[x].filename
121             if self.should_ignore_filename(filename):
122                 continue
123
124             loading_function = s[x].function
125             if filename in self.module_by_filename_cache:
126                 loading_module = self.module_by_filename_cache[filename]
127             else:
128                 self.repopulate_modules_by_filename()
129                 loading_module = self.module_by_filename_cache.get(filename, 'unknown')
130
131             path = self.tree_node_by_module.get(loading_module, [])
132             path.extend([loaded_module])
133             self.tree.insert(path)
134             self.tree_node_by_module[loading_module] = path
135
136             msg = f'*** Import {loaded_module} from {filename}:{s[x].lineno} in {loading_module}::{loading_function}'
137             logger.debug(msg)
138             print(msg)
139             return
140         msg = f'*** Import {loaded_module} from ?????'
141         logger.debug(msg)
142         print(msg)
143
144     def find_importer(self, module: str):
145         if module in self.tree_node_by_module:
146             node = self.tree_node_by_module[module]
147             return node
148         return []
149
150
151 # Audit import events?  Note: this runs early in the lifetime of the
152 # process (assuming that import bootstrap happens early); config has
153 # (probably) not yet been loaded or parsed the commandline.  Also,
154 # some things have probably already been imported while we weren't
155 # watching so this information may be incomplete.
156 #
157 # Also note: move bootstrap up in the global import list to catch
158 # more import events and have a more complete record.
159 import_interceptor = None
160 for arg in sys.argv:
161     if arg == '--audit_import_events':
162         import_interceptor = ImportInterceptor()
163         sys.meta_path = [import_interceptor] + sys.meta_path
164
165
166 def dump_all_objects() -> None:
167     global import_interceptor
168     messages = {}
169     all_modules = sys.modules
170     for obj in object.__subclasses__():
171         if not hasattr(obj, '__name__'):
172             continue
173         klass = obj.__name__
174         if not hasattr(obj, '__module__'):
175             continue
176         class_mod_name = obj.__module__
177         if class_mod_name in all_modules:
178             mod = all_modules[class_mod_name]
179             if not hasattr(mod, '__name__'):
180                 mod_name = class_mod_name
181             else:
182                 mod_name = mod.__name__
183             if hasattr(mod, '__file__'):
184                 mod_file = mod.__file__
185             else:
186                 mod_file = 'unknown'
187             if import_interceptor is not None:
188                 import_path = import_interceptor.find_importer(mod_name)
189             else:
190                 import_path = 'unknown'
191             msg = f'{class_mod_name}::{klass} ({mod_file})'
192             if import_path != 'unknown' and len(import_path) > 0:
193                 msg += f' imported by {import_path}'
194             messages[f'{class_mod_name}::{klass}'] = msg
195     for x in sorted(messages.keys()):
196         logger.debug(messages[x])
197         print(messages[x])
198
199
200 def initialize(entry_point):
201     """
202     Remember to initialize config, initialize logging, set/log a random
203     seed, etc... before running main.
204
205     """
206
207     @functools.wraps(entry_point)
208     def initialize_wrapper(*args, **kwargs):
209         # Hook top level unhandled exceptions, maybe invoke debugger.
210         if sys.excepthook == sys.__excepthook__:
211             sys.excepthook = handle_uncaught_exception
212
213         # Try to figure out the name of the program entry point.  Then
214         # parse configuration (based on cmdline flags, environment vars
215         # etc...)
216         if (
217             '__globals__' in entry_point.__dict__
218             and '__file__' in entry_point.__globals__
219         ):
220             config.parse(entry_point.__globals__['__file__'])
221         else:
222             config.parse(None)
223
224         # Initialize logging... and log some remembered messages from
225         # config module.
226         logging_utils.initialize_logging(logging.getLogger())
227         config.late_logging()
228
229         # Allow programs that don't bother to override the random seed
230         # to be replayed via the commandline.
231         import random
232
233         random_seed = config.config['set_random_seed']
234         if random_seed is not None:
235             random_seed = random_seed[0]
236         else:
237             random_seed = int.from_bytes(os.urandom(4), 'little')
238
239         if config.config['show_random_seed']:
240             msg = f'Global random seed is: {random_seed}'
241             logger.debug(msg)
242             print(msg)
243         random.seed(random_seed)
244
245         # Do it, invoke the user's code.  Pay attention to how long it takes.
246         logger.debug(f'Starting {entry_point.__name__} (program entry point)')
247         ret = None
248         import stopwatch
249
250         if config.config['run_profiler']:
251             import cProfile
252             from pstats import SortKey
253
254             with stopwatch.Timer() as t:
255                 cProfile.runctx(
256                     "ret = entry_point(*args, **kwargs)",
257                     globals(),
258                     locals(),
259                     None,
260                     SortKey.CUMULATIVE,
261                 )
262         else:
263             with stopwatch.Timer() as t:
264                 ret = entry_point(*args, **kwargs)
265
266         logger.debug(f'{entry_point.__name__} (program entry point) returned {ret}.')
267
268         if config.config['dump_all_objects']:
269             dump_all_objects()
270
271         if config.config['audit_import_events']:
272             global import_interceptor
273             if import_interceptor is not None:
274                 print(import_interceptor.tree)
275
276         walltime = t()
277         (utime, stime, cutime, cstime, elapsed_time) = os.times()
278         logger.debug(
279             '\n'
280             f'user: {utime}s\n'
281             f'system: {stime}s\n'
282             f'child user: {cutime}s\n'
283             f'child system: {cstime}s\n'
284             f'machine uptime: {elapsed_time}s\n'
285             f'walltime: {walltime}s'
286         )
287
288         # If it doesn't return cleanly, call attention to the return value.
289         if ret is not None and ret != 0:
290             logger.error(f'Exit {ret}')
291         else:
292             logger.debug(f'Exit {ret}')
293         sys.exit(ret)
294
295     return initialize_wrapper