Easier and more self documenting patterns for loading/saving Persistent
[python_utils.git] / remote_worker.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """A simple utility to unpickle some code, run it, and pickle the
6 results.
7 """
8
9 import logging
10 import os
11 import signal
12 import sys
13 import threading
14 import time
15 from typing import Optional
16
17 import cloudpickle  # type: ignore
18 import psutil  # type: ignore
19
20 import argparse_utils
21 import bootstrap
22 import config
23 from stopwatch import Timer
24 from thread_utils import background_thread
25
# Module-level logger.  NOTE(review): keyed by __file__ rather than the
# conventional __name__ — presumably deliberate so log records show the
# script path; confirm before changing.
logger = logging.getLogger(__file__)

# Register this module's command-line flags with the shared config system.
cfg = config.add_commandline_args(
    f"Remote Worker ({__file__})",
    "Helper to run pickled code remotely and return results",
)
cfg.add_argument(
    '--code_file',
    type=str,
    required=True,
    metavar='FILENAME',
    help='The location of the bundle of code to execute.',
)
cfg.add_argument(
    '--result_file',
    type=str,
    required=True,
    metavar='FILENAME',
    help='The location where we should write the computation results.',
)
cfg.add_argument(
    '--watch_for_cancel',
    # ActionNoYes presumably registers both a yes and a no form of the
    # flag (project helper; see argparse_utils).
    action=argparse_utils.ActionNoYes,
    default=True,
    help='Should we watch for the cancellation of our parent ssh process?',
)
53
@background_thread
def watch_for_cancel(terminate_event: threading.Event) -> None:
    """Background thread body: die if our controlling sshd goes away.

    This worker is expected to run under a remote ssh invocation.  Once a
    second, scan our process ancestry; if no process with 'ssh' in its
    name remains, assume the remote caller cancelled us and tear the whole
    process down (SIGTERM first, SIGKILL five seconds later as a fallback).

    Args:
        terminate_event: set by the main thread to request a clean exit
            of this watcher.
    """
    logger.debug('Starting up background thread...')
    process = psutil.Process()  # defaults to the current pid
    # Check the terminate request at the top of each iteration so that a
    # clean shutdown can never be preempted by the suicide path below.
    while not terminate_event.is_set():
        saw_sshd = False
        for ancestor in process.parents():
            name = ancestor.name()
            logger.debug('Ancestor process %s (pid=%d)', name, ancestor.pid)
            if 'ssh' in name.lower():
                saw_sshd = True
                break
        if not saw_sshd:
            logger.error('Did not see sshd in our ancestors list?!  Committing suicide.')
            os.system('pstree')  # best-effort debugging breadcrumb
            os.kill(os.getpid(), signal.SIGTERM)
            time.sleep(5.0)  # give SIGTERM handlers a chance to run
            os.kill(os.getpid(), signal.SIGKILL)
            sys.exit(-1)
        # Wait on the event instead of time.sleep() so a shutdown request
        # is honored immediately rather than up to one second late.
        terminate_event.wait(1.0)
78
79
def cleanup_and_exit(
    thread: Optional[threading.Thread],
    stop_thread: Optional[threading.Event],
    exit_code: int,
) -> None:
    """Stop the cancellation-watcher thread (if any) and exit the process.

    Args:
        thread: the watcher thread returned by watch_for_cancel, or None
            if watching was disabled.
        stop_thread: the event telling the watcher to stop, or None.
        exit_code: process exit status to report to our invoker.

    Raises:
        SystemExit: always, carrying exit_code.
    """
    if stop_thread is not None:
        stop_thread.set()
        # The original used `assert thread is not None` here; assert is
        # stripped under `python -O`, so guard explicitly instead.
        if thread is not None:
            thread.join()
    sys.exit(exit_code)
90
91
@bootstrap.initialize
def main() -> None:
    """Read a pickled (fun, args, kwargs) bundle, run it, pickle the result.

    Exit codes: 0 on success, 1 on read failure, 2 on deserialization
    failure, 3 on result-serialization failure, 4 on write failure, and
    5 if the user code itself raised.
    """
    in_file = config.config['code_file']
    out_file = config.config['result_file']

    # Optionally start a watcher that kills us if our ssh parent dies.
    thread = None
    stop_thread = None
    if config.config['watch_for_cancel']:
        (thread, stop_thread) = watch_for_cancel()

    logger.debug('Reading %s.', in_file)
    try:
        with open(in_file, 'rb') as rb:
            serialized = rb.read()
    except Exception as e:
        logger.exception(e)
        logger.critical('Problem reading %s. Aborting.', in_file)
        cleanup_and_exit(thread, stop_thread, 1)

    logger.debug('Deserializing %s', in_file)
    try:
        # SECURITY: cloudpickle.loads executes arbitrary code.  This is
        # only safe because in_file is produced by our own dispatcher;
        # never point this worker at untrusted input.
        fun, args, kwargs = cloudpickle.loads(serialized)
    except Exception as e:
        logger.exception(e)
        logger.critical('Problem deserializing %s. Aborting.', in_file)
        cleanup_and_exit(thread, stop_thread, 2)

    logger.debug('Invoking user code...')
    try:
        with Timer() as t:
            ret = fun(*args, **kwargs)
    except Exception as e:
        # Previously an exception in user code propagated out of main,
        # skipping watcher-thread cleanup while every other failure path
        # cleaned up; handle it like the rest.  (cleanup_and_exit raises
        # SystemExit, which `except Exception` does not catch.)
        logger.exception(e)
        logger.critical('User code raised an exception. Aborting.')
        cleanup_and_exit(thread, stop_thread, 5)
    logger.debug('User code took %.1fs', t())

    logger.debug('Serializing results')
    try:
        serialized = cloudpickle.dumps(ret)
    except Exception as e:
        logger.exception(e)
        logger.critical('Could not serialize result (%s). Aborting.', type(ret))
        cleanup_and_exit(thread, stop_thread, 3)

    logger.debug('Writing %s', out_file)
    try:
        with open(out_file, 'wb') as wb:
            wb.write(serialized)
    except Exception as e:
        logger.exception(e)
        logger.critical('Error writing %s. Aborting.', out_file)
        cleanup_and_exit(thread, stop_thread, 4)
    cleanup_and_exit(thread, stop_thread, 0)
141
142
# Run only when invoked as a script (not on import).
if __name__ == '__main__':
    main()