Change locking boundaries for shared dict. Add a unit test.
[python_utils.git] / bootstrap.py
1 #!/usr/bin/env python3
2
3 import functools
4 import logging
5 import os
6 from inspect import stack
7 import sys
8
9 # This module is commonly used by others in here and should avoid
10 # taking any unnecessary dependencies back on them.
11
12 from argparse_utils import ActionNoYes
13 import config
14 import logging_utils
15
16 logger = logging.getLogger(__name__)
17
18 args = config.add_commandline_args(
19     f'Bootstrap ({__file__})',
20     'Args related to python program bootstrapper and Swiss army knife',
21 )
22 args.add_argument(
23     '--debug_unhandled_exceptions',
24     action=ActionNoYes,
25     default=False,
26     help='Break into pdb on top level unhandled exceptions.',
27 )
28 args.add_argument(
29     '--show_random_seed',
30     action=ActionNoYes,
31     default=False,
32     help='Should we display (and log.debug) the global random seed?',
33 )
34 args.add_argument(
35     '--set_random_seed',
36     type=int,
37     nargs=1,
38     default=None,
39     metavar='SEED_INT',
40     help='Override the global random seed with a particular number.',
41 )
42 args.add_argument(
43     '--dump_all_objects',
44     action=ActionNoYes,
45     default=False,
46     help='Should we dump the Python import tree before main?',
47 )
48 args.add_argument(
49     '--audit_import_events',
50     action=ActionNoYes,
51     default=False,
52     help='Should we audit all import events?',
53 )
54 args.add_argument(
55     '--run_profiler',
56     action=ActionNoYes,
57     default=False,
58     help='Should we run cProfile on this code?',
59 )
60 args.add_argument(
61     '--trace_memory',
62     action=ActionNoYes,
63     default=False,
64     help='Should we record/report on memory utilization?',
65 )
66
67 original_hook = sys.excepthook
68
69
70 def handle_uncaught_exception(exc_type, exc_value, exc_tb):
71     """
72     Top-level exception handler for exceptions that make it past any exception
73     handlers in the python code being run.  Logs the error and stacktrace then
74     maybe attaches a debugger.
75
76     """
77     global original_hook
78     msg = f'Unhandled top level exception {exc_type}'
79     logger.exception(msg)
80     print(msg, file=sys.stderr)
81     if issubclass(exc_type, KeyboardInterrupt):
82         sys.__excepthook__(exc_type, exc_value, exc_tb)
83         return
84     else:
85         if not sys.stderr.isatty() or not sys.stdin.isatty():
86             # stdin or stderr is redirected, just do the normal thing
87             original_hook(exc_type, exc_value, exc_tb)
88         else:
89             # a terminal is attached and stderr is not redirected, maybe debug.
90             import traceback
91
92             traceback.print_exception(exc_type, exc_value, exc_tb)
93             if config.config['debug_unhandled_exceptions']:
94                 import pdb
95
96                 logger.info("Invoking the debugger...")
97                 pdb.pm()
98             else:
99                 original_hook(exc_type, exc_value, exc_tb)
100
101
102 class ImportInterceptor(object):
103     def __init__(self):
104         import collect.trie
105
106         self.module_by_filename_cache = {}
107         self.repopulate_modules_by_filename()
108         self.tree = collect.trie.Trie()
109         self.tree_node_by_module = {}
110
111     def repopulate_modules_by_filename(self):
112         self.module_by_filename_cache.clear()
113         for mod in sys.modules:
114             if hasattr(sys.modules[mod], '__file__'):
115                 fname = getattr(sys.modules[mod], '__file__')
116             else:
117                 fname = 'unknown'
118             self.module_by_filename_cache[fname] = mod
119
120     def should_ignore_filename(self, filename: str) -> bool:
121         return 'importlib' in filename or 'six.py' in filename
122
123     def find_spec(self, loaded_module, path=None, target=None):
124         s = stack()
125         for x in range(3, len(s)):
126             filename = s[x].filename
127             if self.should_ignore_filename(filename):
128                 continue
129
130             loading_function = s[x].function
131             if filename in self.module_by_filename_cache:
132                 loading_module = self.module_by_filename_cache[filename]
133             else:
134                 self.repopulate_modules_by_filename()
135                 loading_module = self.module_by_filename_cache.get(filename, 'unknown')
136
137             path = self.tree_node_by_module.get(loading_module, [])
138             path.extend([loaded_module])
139             self.tree.insert(path)
140             self.tree_node_by_module[loading_module] = path
141
142             msg = f'*** Import {loaded_module} from {filename}:{s[x].lineno} in {loading_module}::{loading_function}'
143             logger.debug(msg)
144             print(msg)
145             return
146         msg = f'*** Import {loaded_module} from ?????'
147         logger.debug(msg)
148         print(msg)
149
150     def find_importer(self, module: str):
151         if module in self.tree_node_by_module:
152             node = self.tree_node_by_module[module]
153             return node
154         return []
155
156
157 # Audit import events?  Note: this runs early in the lifetime of the
158 # process (assuming that import bootstrap happens early); config has
159 # (probably) not yet been loaded or parsed the commandline.  Also,
160 # some things have probably already been imported while we weren't
161 # watching so this information may be incomplete.
162 #
163 # Also note: move bootstrap up in the global import list to catch
164 # more import events and have a more complete record.
165 import_interceptor = None
166 for arg in sys.argv:
167     if arg == '--audit_import_events':
168         import_interceptor = ImportInterceptor()
169         sys.meta_path = [import_interceptor] + sys.meta_path
170
171
172 def dump_all_objects() -> None:
173     global import_interceptor
174     messages = {}
175     all_modules = sys.modules
176     for obj in object.__subclasses__():
177         if not hasattr(obj, '__name__'):
178             continue
179         klass = obj.__name__
180         if not hasattr(obj, '__module__'):
181             continue
182         class_mod_name = obj.__module__
183         if class_mod_name in all_modules:
184             mod = all_modules[class_mod_name]
185             if not hasattr(mod, '__name__'):
186                 mod_name = class_mod_name
187             else:
188                 mod_name = mod.__name__
189             if hasattr(mod, '__file__'):
190                 mod_file = mod.__file__
191             else:
192                 mod_file = 'unknown'
193             if import_interceptor is not None:
194                 import_path = import_interceptor.find_importer(mod_name)
195             else:
196                 import_path = 'unknown'
197             msg = f'{class_mod_name}::{klass} ({mod_file})'
198             if import_path != 'unknown' and len(import_path) > 0:
199                 msg += f' imported by {import_path}'
200             messages[f'{class_mod_name}::{klass}'] = msg
201     for x in sorted(messages.keys()):
202         logger.debug(messages[x])
203         print(messages[x])
204
205
206 def initialize(entry_point):
207     """
208     Remember to initialize config, initialize logging, set/log a random
209     seed, etc... before running main.
210
211     """
212
213     @functools.wraps(entry_point)
214     def initialize_wrapper(*args, **kwargs):
215         # Hook top level unhandled exceptions, maybe invoke debugger.
216         if sys.excepthook == sys.__excepthook__:
217             sys.excepthook = handle_uncaught_exception
218
219         # Try to figure out the name of the program entry point.  Then
220         # parse configuration (based on cmdline flags, environment vars
221         # etc...)
222         if (
223             '__globals__' in entry_point.__dict__
224             and '__file__' in entry_point.__globals__
225         ):
226             config.parse(entry_point.__globals__['__file__'])
227         else:
228             config.parse(None)
229
230         if config.config['trace_memory']:
231             import tracemalloc
232
233             tracemalloc.start()
234
235         # Initialize logging... and log some remembered messages from
236         # config module.
237         logging_utils.initialize_logging(logging.getLogger())
238         config.late_logging()
239
240         # Maybe log some info about the python interpreter itself.
241         logger.debug(
242             f'Platform: {sys.platform}, maxint=0x{sys.maxsize:x}, byteorder={sys.byteorder}'
243         )
244         logger.debug(f'Python interpreter version: {sys.version}')
245         logger.debug(f'Python implementation: {sys.implementation}')
246         logger.debug(f'Python C API version: {sys.api_version}')
247         logger.debug(f'Python path: {sys.path}')
248
249         # Allow programs that don't bother to override the random seed
250         # to be replayed via the commandline.
251         import random
252
253         random_seed = config.config['set_random_seed']
254         if random_seed is not None:
255             random_seed = random_seed[0]
256         else:
257             random_seed = int.from_bytes(os.urandom(4), 'little')
258
259         if config.config['show_random_seed']:
260             msg = f'Global random seed is: {random_seed}'
261             logger.debug(msg)
262             print(msg)
263         random.seed(random_seed)
264
265         # Do it, invoke the user's code.  Pay attention to how long it takes.
266         logger.debug(f'Starting {entry_point.__name__} (program entry point)')
267         ret = None
268         import stopwatch
269
270         if config.config['run_profiler']:
271             import cProfile
272             from pstats import SortKey
273
274             with stopwatch.Timer() as t:
275                 cProfile.runctx(
276                     "ret = entry_point(*args, **kwargs)",
277                     globals(),
278                     locals(),
279                     None,
280                     SortKey.CUMULATIVE,
281                 )
282         else:
283             with stopwatch.Timer() as t:
284                 ret = entry_point(*args, **kwargs)
285
286         logger.debug(f'{entry_point.__name__} (program entry point) returned {ret}.')
287
288         if config.config['trace_memory']:
289             snapshot = tracemalloc.take_snapshot()
290             top_stats = snapshot.statistics('lineno')
291             print()
292             print("--trace_memory's top 10 memory using files:")
293             for stat in top_stats[:10]:
294                 print(stat)
295
296         if config.config['dump_all_objects']:
297             dump_all_objects()
298
299         if config.config['audit_import_events']:
300             global import_interceptor
301             if import_interceptor is not None:
302                 print(import_interceptor.tree)
303
304         walltime = t()
305         (utime, stime, cutime, cstime, elapsed_time) = os.times()
306         logger.debug(
307             '\n'
308             f'user: {utime}s\n'
309             f'system: {stime}s\n'
310             f'child user: {cutime}s\n'
311             f'child system: {cstime}s\n'
312             f'machine uptime: {elapsed_time}s\n'
313             f'walltime: {walltime}s'
314         )
315
316         # If it doesn't return cleanly, call attention to the return value.
317         if ret is not None and ret != 0:
318             logger.error(f'Exit {ret}')
319         else:
320             logger.debug(f'Exit {ret}')
321         sys.exit(ret)
322
323     return initialize_wrapper