211b2132ff07c8ef0caa54b4b68e89e3fe7c8caf
[python_utils.git] / remote_worker.py
1 #!/usr/bin/env python3
2
3 """A simple utility to unpickle some code, run it, and pickle the
4 results.
5 """
6
7 import logging
8 import os
9 import signal
10 import threading
11 import sys
12 import time
13
14 import cloudpickle  # type: ignore
15 import psutil  # type: ignore
16
17 import argparse_utils
18 import bootstrap
19 import config
20 from thread_utils import background_thread
21
22
# Module-level logger; keyed on this file's path rather than __name__.
logger = logging.getLogger(__file__)

# Register this program's command-line arguments with the shared config
# framework; parsing happens later, inside bootstrap.initialize.
cfg = config.add_commandline_args(
    f"Remote Worker ({__file__})",
    "Helper to run pickled code remotely and return results",
)
cfg.add_argument(
    '--code_file',
    type=str,
    required=True,
    metavar='FILENAME',
    help='The location of the bundle of code to execute.'
)
cfg.add_argument(
    '--result_file',
    type=str,
    required=True,
    metavar='FILENAME',
    help='The location where we should write the computation results.'
)
cfg.add_argument(
    '--watch_for_cancel',
    # NOTE(review): ActionNoYes presumably generates paired yes/no flags
    # (e.g. --watch_for_cancel / --no_watch_for_cancel) -- confirm in
    # argparse_utils.
    action=argparse_utils.ActionNoYes,
    default=True,
    help='Should we watch for the cancellation of our parent ssh process?'
)
50
@background_thread
def watch_for_cancel(terminate_event: threading.Event) -> None:
    """Background thread: watch for the disappearance of our parent sshd.

    This worker is expected to have been launched over ssh.  Once per
    second we scan our process ancestry; if no ancestor whose name
    contains 'ssh' remains, we assume the remote caller cancelled the
    job and terminate ourselves (SIGTERM, then SIGKILL as a fallback).

    Args:
        terminate_event: set by the caller to ask this thread to exit.
            (Supplied by the @background_thread decorator machinery --
            TODO confirm.)
    """
    logger.debug('Starting up background thread...')
    me = psutil.Process(os.getpid())
    while True:
        saw_sshd = False
        for ancestor in me.parents():
            name = ancestor.name()
            # Lazy %-args: this runs every second and per ancestor; avoid
            # paying f-string formatting cost when debug logging is off.
            logger.debug('Ancestor process %s (pid=%d)', name, ancestor.pid)
            if 'ssh' in name.lower():
                saw_sshd = True
                break
        if not saw_sshd:
            logger.error('Did not see sshd in our ancestors list?!  Committing suicide.')
            os.system('pstree')
            os.kill(os.getpid(), signal.SIGTERM)
            # Give SIGTERM a grace period before escalating.
            time.sleep(5.0)
            os.kill(os.getpid(), signal.SIGKILL)
            sys.exit(-1)
        # Event.wait doubles as the 1s poll interval AND wakes immediately
        # when termination is requested; the original time.sleep(1.0) could
        # delay shutdown by up to a full second.
        if terminate_event.wait(1.0):
            return
75
76
@bootstrap.initialize
def main() -> None:
    """Unpickle a (fun, args, kwargs) bundle, run it, pickle the result.

    Reads the bundle from --code_file, writes the cloudpickled return
    value to --result_file.  Optionally runs a watcher thread that kills
    this process if our parent ssh session disappears (cancellation).

    Exits with -1 (after logging) on any read / deserialize / serialize /
    write failure.
    """
    in_file = config.config['code_file']
    out_file = config.config['result_file']

    thread = None
    stop_thread = None
    if config.config['watch_for_cancel']:
        (thread, stop_thread) = watch_for_cancel()

    def bail(message: str) -> None:
        # Shared failure path: log, stop the watcher thread (if one was
        # started), and exit.  The None guard fixes a bug in the original:
        # with --no_watch_for_cancel, stop_thread is None and every error
        # path crashed with AttributeError instead of exiting cleanly.
        logger.critical(message)
        if stop_thread is not None:
            stop_thread.set()
        sys.exit(-1)

    logger.debug(f'Reading {in_file}.')
    try:
        with open(in_file, 'rb') as rb:
            serialized = rb.read()
    except Exception as e:
        logger.exception(e)
        bail(f'Problem reading {in_file}.  Aborting.')

    logger.debug(f'Deserializing {in_file}.')
    try:
        fun, args, kwargs = cloudpickle.loads(serialized)
    except Exception as e:
        logger.exception(e)
        bail(f'Problem deserializing {in_file}.  Aborting.')

    logger.debug('Invoking user code...')
    start = time.time()
    ret = fun(*args, **kwargs)
    end = time.time()
    logger.debug(f'User code took {end - start:.1f}s')

    logger.debug('Serializing results')
    try:
        serialized = cloudpickle.dumps(ret)
    except Exception as e:
        logger.exception(e)
        bail(f'Could not serialize result ({type(ret)}).  Aborting.')

    logger.debug(f'Writing {out_file}.')
    try:
        with open(out_file, 'wb') as wb:
            wb.write(serialized)
    except Exception as e:
        logger.exception(e)
        bail(f'Error writing {out_file}.  Aborting.')

    # Success: shut down the watcher and wait for it to finish.
    if stop_thread is not None:
        stop_thread.set()
        thread.join()
133
134
# Script entry point; config parsing / setup is handled by the
# @bootstrap.initialize decorator on main().
if __name__ == '__main__':
    main()