75dfe8e46cb3d34e28f44267129ff38020d23a7a
[python_utils.git] / remote_worker.py
1 #!/usr/bin/env python3
2
3 """A simple utility to unpickle some code, run it, and pickle the
4 results.
5 """
6
7 import logging
8 import os
9 import signal
10 import sys
11 import threading
12 import time
13 from typing import Optional
14
15 import cloudpickle  # type: ignore
16 import psutil  # type: ignore
17
18 import argparse_utils
19 import bootstrap
20 import config
21 from stopwatch import Timer
22 from thread_utils import background_thread
23
# Module-level logger, named after this file's path.
logger = logging.getLogger(__file__)

# Register this program's command-line flags with the shared config
# module; the parsed values are read back through config.config in main().
cfg = config.add_commandline_args(
    f"Remote Worker ({__file__})",
    "Helper to run pickled code remotely and return results",
)

# Input: the file holding the serialized code bundle to run.
cfg.add_argument(
    '--code_file',
    metavar='FILENAME',
    required=True,
    type=str,
    help='The location of the bundle of code to execute.',
)

# Output: the file that receives the serialized result.
cfg.add_argument(
    '--result_file',
    metavar='FILENAME',
    required=True,
    type=str,
    help='The location where we should write the computation results.',
)

# Whether to start the background watcher that checks for an ssh
# ancestor process (on by default).
cfg.add_argument(
    '--watch_for_cancel',
    default=True,
    action=argparse_utils.ActionNoYes,
    help='Should we watch for the cancellation of our parent ssh process?',
)
50
51
@background_thread
def watch_for_cancel(terminate_event: threading.Event) -> None:
    """Kill this process once no ancestor process looks like ssh.

    Roughly once per second, walk this process' ancestors looking for
    one whose name contains 'ssh' (case-insensitive).  If none is found,
    log an error, dump the process tree, send ourselves SIGTERM and, five
    seconds later, SIGKILL.

    NOTE(review): the background_thread decorator is defined elsewhere;
    judging by how main() calls this function, it presumably runs the body
    on a new thread and supplies terminate_event -- confirm in thread_utils.

    Args:
        terminate_event: Set by the caller to ask this watcher to return.
    """
    logger.debug('Starting up background thread...')
    p = psutil.Process(os.getpid())
    while True:
        saw_sshd = False
        for ancestor in p.parents():
            try:
                name = ancestor.name()
            except psutil.Error:
                # The ancestor exited (or became unreadable) between being
                # listed and being inspected.  Skip it rather than letting
                # the exception kill this watcher thread, which would
                # silently disable cancellation detection.
                continue
            pid = ancestor.pid
            logger.debug('Ancestor process %s (pid=%d)', name, pid)
            if 'ssh' in name.lower():
                saw_sshd = True
                break
        if not saw_sshd:
            logger.error('Did not see sshd in our ancestors list?!  Committing suicide.')
            os.system('pstree')
            # Ask politely first, then force the issue if we are still
            # alive five seconds later.
            os.kill(os.getpid(), signal.SIGTERM)
            time.sleep(5.0)
            os.kill(os.getpid(), signal.SIGKILL)
            # Not expected to be reached once SIGKILL has been delivered.
            sys.exit(-1)
        # Wait on the event instead of sleeping: the polling interval is
        # still one second, but a shutdown request wakes us immediately so
        # the caller's join() does not stall for the rest of the sleep.
        if terminate_event.wait(timeout=1.0):
            return
76
77
def cleanup_and_exit(
    thread: Optional[threading.Thread],
    stop_thread: Optional[threading.Event],
    exit_code: int,
) -> None:
    """Stop the watcher thread, if there is one, then exit the process.

    Args:
        thread: The watcher thread to join, or None when no watcher runs.
        stop_thread: The event that tells the watcher to stop, or None when
            no watcher runs.  When it is given, thread must be given too.
        exit_code: The status to exit with.

    Raises:
        SystemExit: Always; this function does not return normally.
    """
    watcher_active = stop_thread is not None
    if watcher_active:
        # Signal first, then wait for the watcher to notice and finish.
        stop_thread.set()
        assert thread is not None
        thread.join()
    raise SystemExit(exit_code)
88
89
@bootstrap.initialize
def main() -> None:
    """Load a pickled (fun, args, kwargs) bundle, run it, pickle the result.

    The bundle is read from --code_file and the result is written to
    --result_file.  Unless --watch_for_cancel is turned off, a watcher is
    started first and is stopped again on every way out of this function.

    Exit status (via cleanup_and_exit):
        0: success
        1: the code file could not be read
        2: the code file could not be deserialized
        3: the result could not be serialized
        4: the result file could not be written

    An exception raised by the user code itself is not translated into an
    exit status; it propagates to the caller after the watcher is stopped.
    """
    in_file = config.config['code_file']
    out_file = config.config['result_file']

    thread = None
    stop_thread = None
    if config.config['watch_for_cancel']:
        (thread, stop_thread) = watch_for_cancel()

    logger.debug('Reading %s.', in_file)
    try:
        with open(in_file, 'rb') as rb:
            serialized = rb.read()
    except Exception as e:
        logger.exception(e)
        logger.critical('Problem reading %s. Aborting.', in_file)
        cleanup_and_exit(thread, stop_thread, 1)

    logger.debug('Deserializing %s', in_file)
    try:
        # SECURITY: unpickling executes arbitrary code, so --code_file must
        # only ever point at a bundle produced by a trusted peer.
        fun, args, kwargs = cloudpickle.loads(serialized)
    except Exception as e:
        logger.exception(e)
        logger.critical('Problem deserializing %s. Aborting.', in_file)
        cleanup_and_exit(thread, stop_thread, 2)

    logger.debug('Invoking user code...')
    try:
        with Timer() as t:
            ret = fun(*args, **kwargs)
    except BaseException:
        # Every other failure path stops the watcher; do the same here so
        # a still-running watcher thread cannot keep the process alive
        # after the user code fails.  Re-raise so the exception reaches
        # the caller exactly as it did before.
        if stop_thread is not None:
            stop_thread.set()
            if thread is not None:
                thread.join()
        raise
    logger.debug('User code took %.1fs', t())

    logger.debug('Serializing results')
    try:
        serialized = cloudpickle.dumps(ret)
    except Exception as e:
        logger.exception(e)
        logger.critical('Could not serialize result (%s). Aborting.', type(ret))
        cleanup_and_exit(thread, stop_thread, 3)

    logger.debug('Writing %s', out_file)
    try:
        with open(out_file, 'wb') as wb:
            wb.write(serialized)
    except Exception as e:
        logger.exception(e)
        logger.critical('Error writing %s. Aborting.', out_file)
        cleanup_and_exit(thread, stop_thread, 4)
    cleanup_and_exit(thread, stop_thread, 0)


# Run the worker only when executed as a script, not when imported.
if __name__ == '__main__':
    main()