A bunch of changes...
[python_utils.git] / remote_worker.py
1 #!/usr/bin/env python3
2
3 """A simple utility to unpickle some code, run it, and pickle the
4 results.
5 """
6
7 import logging
8 import os
9 import platform
10 import signal
11 import threading
12 import sys
13 import time
14
15 import cloudpickle  # type: ignore
16 import psutil  # type: ignore
17
18 import argparse_utils
19 import bootstrap
20 import config
21 from thread_utils import background_thread
22
23
logger = logging.getLogger(__file__)

# Register this script's command-line flags with the shared config module.
# This runs at import time, before main() reads config.config.
cfg = config.add_commandline_args(
    f"Remote Worker ({__file__})",
    "Helper to run pickled code remotely and return results",
)

# Input: path of the pickled bundle that main() reads and deserializes.
cfg.add_argument(
    '--code_file',
    metavar='FILENAME',
    required=True,
    type=str,
    help='The location of the bundle of code to execute.',
)

# Output: path where main() writes the pickled return value.
cfg.add_argument(
    '--result_file',
    metavar='FILENAME',
    required=True,
    type=str,
    help='The location where we should write the computation results.',
)

# Optional watchdog toggle; off by default.  When on, main() starts
# watch_for_cancel() before running the user code.
cfg.add_argument(
    '--watch_for_cancel',
    default=False,
    action=argparse_utils.ActionNoYes,
    help='Should we watch for the cancellation of our parent ssh process?',
)
50
51
@background_thread
def watch_for_cancel(terminate_event: threading.Event) -> None:
    """Terminate this process if no ssh process is among its ancestors.

    Once per second, walk this process's ancestor chain looking for a
    process whose name contains 'ssh' (case-insensitive).  If none is
    found, dump the process tree and kill ourselves: SIGTERM first, then
    SIGKILL five seconds later if we are still alive.

    main() unpacks the decorated call as ``(thread, stop_event)``, so the
    @background_thread decorator presumably creates and supplies
    terminate_event -- confirm against thread_utils.

    Args:
        terminate_event: set by the caller to ask this watcher to return.
    """
    if platform.node() == 'VIDEO-COMPUTER':
        logger.warning('Background thread not allowed on retarded computers, sorry.')
        return
    logger.debug('Starting up background thread...')
    p = psutil.Process(os.getpid())
    while True:
        saw_sshd = False
        for ancestor in p.parents():
            name = ancestor.name()
            pid = ancestor.pid
            logger.debug(f'Ancestor process {name} (pid={pid})')
            if 'ssh' in name.lower():
                saw_sshd = True
                break
        if not saw_sshd:
            logger.error('Did not see sshd in our ancestors list?!  Committing suicide.')
            # Dump the process tree to aid debugging before we go away.
            os.system('pstree')
            # Ask politely first, then force the issue if still running.
            os.kill(os.getpid(), signal.SIGTERM)
            time.sleep(5.0)
            os.kill(os.getpid(), signal.SIGKILL)
            sys.exit(-1)
        # Event.wait() doubles as the poll interval and the stop check: it
        # returns True as soon as the event is set, so a join() on this
        # thread no longer stalls for the remainder of a fixed sleep, and no
        # further ancestor check runs once shutdown has been requested.
        if terminate_event.wait(1.0):
            return
79
80
@bootstrap.initialize
def main() -> None:
    """Run a pickled (fun, args, kwargs) bundle and pickle its result.

    Reads the bundle from the --code_file path, calls
    ``fun(*args, **kwargs)``, and writes the cloudpickle-serialized return
    value to the --result_file path.  If --watch_for_cancel is set, the
    watch_for_cancel() watcher is started first and stopped afterwards.

    A failure while reading, deserializing, serializing or writing is
    logged and the process exits with status -1.  An exception raised by
    the user code itself propagates to the caller.
    """
    in_file = config.config['code_file']
    out_file = config.config['result_file']

    thread = None
    stop_thread = None
    if config.config['watch_for_cancel']:
        (thread, stop_thread) = watch_for_cancel()

    def signal_stop() -> None:
        # The watcher only exists when --watch_for_cancel was given.
        # Without this guard the error paths below would raise
        # AttributeError on None instead of exiting with status -1.
        if stop_thread is not None:
            stop_thread.set()

    logger.debug(f'Reading {in_file}.')
    try:
        with open(in_file, 'rb') as rb:
            serialized = rb.read()
    except Exception as e:
        logger.exception(e)
        logger.critical(f'Problem reading {in_file}.  Aborting.')
        signal_stop()
        sys.exit(-1)

    logger.debug(f'Deserializing {in_file}.')
    try:
        # NOTE: unpickling executes arbitrary code; in_file must come from
        # a trusted source.
        fun, args, kwargs = cloudpickle.loads(serialized)
    except Exception as e:
        logger.exception(e)
        logger.critical(f'Problem deserializing {in_file}.  Aborting.')
        signal_stop()
        sys.exit(-1)

    logger.debug('Invoking user code...')
    start = time.time()
    try:
        ret = fun(*args, **kwargs)
    except BaseException:
        # Match the other failure paths: tell the watcher to stop, then let
        # the user code's exception propagate unchanged.
        signal_stop()
        raise
    end = time.time()
    logger.debug(f'User code took {end - start:.1f}s')

    logger.debug('Serializing results')
    try:
        serialized = cloudpickle.dumps(ret)
    except Exception as e:
        logger.exception(e)
        logger.critical(f'Could not serialize result ({type(ret)}).  Aborting.')
        signal_stop()
        sys.exit(-1)

    logger.debug(f'Writing {out_file}.')
    try:
        with open(out_file, 'wb') as wb:
            wb.write(serialized)
    except Exception as e:
        logger.exception(e)
        logger.critical(f'Error writing {out_file}.  Aborting.')
        signal_stop()
        sys.exit(-1)

    # Normal completion: stop the watcher and wait for it to finish.
    if stop_thread is not None:
        stop_thread.set()
        thread.join()
137
138
# Script entry point: run the worker only when executed directly, not when
# this module is imported.
if __name__ == "__main__":
    main()