#!/usr/bin/env python3

"""A simple utility to unpickle some code, run it, and pickle the
results.
"""

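# A rough sketch of the expected file formats, inferred from the
# cloudpickle.loads()/dumps() calls in main() below (function names and
# paths here are purely illustrative):
#
#     import cloudpickle
#
#     # Producer side: bundle a callable plus its arguments...
#     bundle = cloudpickle.dumps((some_function, (1, 2), {'verbose': True}))
#     with open('/tmp/code.bin', 'wb') as wb:
#         wb.write(bundle)
#
#     # ...then, after this worker has run, read back the pickled result.
#     with open('/tmp/result.bin', 'rb') as rb:
#         result = cloudpickle.loads(rb.read())
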
import logging
import os
import signal
import sys
import threading
import time
from typing import Optional

import cloudpickle  # type: ignore
import psutil  # type: ignore

import argparse_utils
import bootstrap
import config
from stopwatch import Timer
from thread_utils import background_thread


logger = logging.getLogger(__file__)

cfg = config.add_commandline_args(
    f"Remote Worker ({__file__})",
    "Helper to run pickled code remotely and return results",
)
cfg.add_argument(
    '--code_file',
    type=str,
    required=True,
    metavar='FILENAME',
    help='The location of the bundle of code to execute.',
)
cfg.add_argument(
    '--result_file',
    type=str,
    required=True,
    metavar='FILENAME',
    help='The location where we should write the computation results.',
)
cfg.add_argument(
    '--watch_for_cancel',
    action=argparse_utils.ActionNoYes,
    default=True,
    help='Should we watch for the cancellation of our parent ssh process?',
)

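# Example invocation (file paths are illustrative; the flags come from the
# config definitions above):
#
#     ./remote_worker.py --code_file /tmp/code.bin --result_file /tmp/result.bin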

@background_thread
def watch_for_cancel(terminate_event: threading.Event) -> None:
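    """Poll our ancestor processes and self-destruct if sshd disappears.

    Note: the @background_thread decorator is assumed to supply the
    terminate_event argument and to make calling this function return a
    (thread, event) pair, which is how main() uses it below.
    """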
    logger.debug('Starting up background thread...')
    p = psutil.Process(os.getpid())
    while True:
        # Walk our ancestor processes looking for the sshd that invoked us.
        saw_sshd = False
        ancestors = p.parents()
        for ancestor in ancestors:
            name = ancestor.name()
            pid = ancestor.pid
            logger.debug(f'Ancestor process {name} (pid={pid})')
            if 'ssh' in name.lower():
                saw_sshd = True
                break
        if not saw_sshd:
            # Our controlling ssh session is gone; escalate from SIGTERM to
            # SIGKILL so we don't outlive the caller.
            logger.error(
                'Did not see sshd in our ancestors list?!  Committing suicide.'
            )
            os.system('pstree')
            os.kill(os.getpid(), signal.SIGTERM)
            time.sleep(5.0)
            os.kill(os.getpid(), signal.SIGKILL)
            sys.exit(-1)
        if terminate_event.is_set():
            return
        time.sleep(1.0)


def cleanup_and_exit(
    thread: Optional[threading.Thread],
    stop_thread: Optional[threading.Event],
    exit_code: int,
) -> None:
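    """Stop the cancellation watcher thread, if any, then exit with exit_code."""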
    if stop_thread is not None:
        stop_thread.set()
        assert thread is not None
        thread.join()
    sys.exit(exit_code)


@bootstrap.initialize
def main() -> None:
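    """Read the code bundle, run it, and write out the pickled result.

    Exit codes mirror the error paths below: 1 = couldn't read the code
    file, 2 = couldn't deserialize it, 3 = couldn't serialize the result,
    4 = couldn't write the result file, 0 = success.
    """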
    in_file = config.config['code_file']
    out_file = config.config['result_file']

    thread = None
    stop_thread = None
    if config.config['watch_for_cancel']:
        (thread, stop_thread) = watch_for_cancel()

    logger.debug(f'Reading {in_file}.')
    try:
        with open(in_file, 'rb') as rb:
            serialized = rb.read()
    except Exception as e:
        logger.exception(e)
        logger.critical(f'Problem reading {in_file}.  Aborting.')
        cleanup_and_exit(thread, stop_thread, 1)

    logger.debug(f'Deserializing {in_file}.')
    try:
        fun, args, kwargs = cloudpickle.loads(serialized)
    except Exception as e:
        logger.exception(e)
        logger.critical(f'Problem deserializing {in_file}.  Aborting.')
        cleanup_and_exit(thread, stop_thread, 2)

    logger.debug('Invoking user code...')
    with Timer() as t:
        ret = fun(*args, **kwargs)
    logger.debug(f'User code took {t():.1f}s')

    logger.debug('Serializing results')
    try:
        serialized = cloudpickle.dumps(ret)
    except Exception as e:
        logger.exception(e)
        logger.critical(f'Could not serialize result ({type(ret)}).  Aborting.')
        cleanup_and_exit(thread, stop_thread, 3)

    logger.debug(f'Writing {out_file}.')
    try:
        with open(out_file, 'wb') as wb:
            wb.write(serialized)
    except Exception as e:
        logger.exception(e)
        logger.critical(f'Error writing {out_file}.  Aborting.')
        cleanup_and_exit(thread, stop_thread, 4)
    cleanup_and_exit(thread, stop_thread, 0)


if __name__ == '__main__':
    main()