630d7e035fd414b166a1dda358ffd6aa430a70cd
[pyutils.git] / src / pyutils / remote_worker.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """A simple utility to unpickle some code, run it, and pickle the
6 results.  Please don't unpickle (or run!) code you do not know.
7
8 This script is used by code in parallelize, namely the
9 :class:`RemoteExecutor`, to schedule work on a remote machine.
10 The code in :file:`parallelize.py` uses a user-defined configuration
11 to schedule work this way.  See that file for setup instructions.
12 """
13
14 import logging
15 import os
16 import signal
17 import sys
18 import threading
19 import time
20 from typing import Optional
21
22 import cloudpickle  # type: ignore
23 import psutil  # type: ignore
24
25 from pyutils import argparse_utils, bootstrap, config
26 from pyutils.parallelize.thread_utils import background_thread
27 from pyutils.stopwatch import Timer
28
29 logger = logging.getLogger(__file__)
30
31 cfg = config.add_commandline_args(
32     f"Remote Worker ({__file__})",
33     "Helper to run pickled code remotely and return results",
34 )
35 cfg.add_argument(
36     '--code_file',
37     type=str,
38     required=True,
39     metavar='FILENAME',
40     help='The location of the bundle of code to execute.',
41 )
42 cfg.add_argument(
43     '--result_file',
44     type=str,
45     required=True,
46     metavar='FILENAME',
47     help='The location where we should write the computation results.',
48 )
49 cfg.add_argument(
50     '--watch_for_cancel',
51     action=argparse_utils.ActionNoYes,
52     default=True,
53     help='Should we watch for the cancellation of our parent ssh process?',
54 )
55
56
57 @background_thread
58 def watch_for_cancel(terminate_event: threading.Event) -> None:
59     logger.debug('Starting up background thread...')
60     p = psutil.Process(os.getpid())
61     while True:
62         saw_sshd = False
63         ancestors = p.parents()
64         for ancestor in ancestors:
65             name = ancestor.name()
66             pid = ancestor.pid
67             logger.debug('Ancestor process %s (pid=%d)', name, pid)
68             if 'ssh' in name.lower():
69                 saw_sshd = True
70                 break
71         if not saw_sshd:
72             logger.error(
73                 'Did not see sshd in our ancestors list?!  Committing suicide.'
74             )
75             os.system('pstree')
76             os.kill(os.getpid(), signal.SIGTERM)
77             time.sleep(5.0)
78             os.kill(os.getpid(), signal.SIGKILL)
79             sys.exit(-1)
80         if terminate_event.is_set():
81             return
82         time.sleep(1.0)
83
84
85 def cleanup_and_exit(
86     thread: Optional[threading.Thread],
87     stop_thread: Optional[threading.Event],
88     exit_code: int,
89 ) -> None:
90     if stop_thread is not None:
91         stop_thread.set()
92         assert thread is not None
93         thread.join()
94     sys.exit(exit_code)
95
96
97 @bootstrap.initialize
98 def main() -> None:
99     in_file = config.config['code_file']
100     assert in_file and type(in_file) == str
101     out_file = config.config['result_file']
102     assert out_file and type(out_file) == str
103
104     thread = None
105     stop_thread = None
106     if config.config['watch_for_cancel']:
107         (thread, stop_thread) = watch_for_cancel()
108
109     logger.debug('Reading %s.', in_file)
110     try:
111         with open(in_file, 'rb') as rb:
112             serialized = rb.read()
113     except Exception as e:
114         logger.exception(e)
115         logger.critical('Problem reading %s. Aborting.', in_file)
116         cleanup_and_exit(thread, stop_thread, 1)
117
118     logger.debug('Deserializing %s', in_file)
119     try:
120         fun, args, kwargs = cloudpickle.loads(serialized)
121     except Exception as e:
122         logger.exception(e)
123         logger.critical('Problem deserializing %s. Aborting.', in_file)
124         cleanup_and_exit(thread, stop_thread, 2)
125
126     logger.debug('Invoking user code...')
127     with Timer() as t:
128         ret = fun(*args, **kwargs)
129     logger.debug('User code took %.1fs', t())
130
131     logger.debug('Serializing results')
132     try:
133         serialized = cloudpickle.dumps(ret)
134     except Exception as e:
135         logger.exception(e)
136         logger.critical('Could not serialize result (%s). Aborting.', type(ret))
137         cleanup_and_exit(thread, stop_thread, 3)
138
139     logger.debug('Writing %s', out_file)
140     try:
141         with open(out_file, 'wb') as wb:
142             wb.write(serialized)
143     except Exception as e:
144         logger.exception(e)
145         logger.critical('Error writing %s. Aborting.', out_file)
146         cleanup_and_exit(thread, stop_thread, 4)
147     cleanup_and_exit(thread, stop_thread, 0)
148
149
150 if __name__ == '__main__':
151     main()