More messing with the project file.
[pyutils.git] / src / pyutils / remote_worker.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """A simple utility to unpickle some code from the filesystem, run it,
6 pickle the results, and save them back on the filesystem.  This file
7 helps :mod:`pyutils.parallelize.parallelize` and
8 :mod:`pyutils.parallelize.executors` implement the
9 :class:`pyutils.parallelize.executors.RemoteExecutor` that distributes
10 work to different machines when code is marked with the
11 `@parallelize(method=Method.REMOTE)` decorator.
12
13 .. warning::
14     Please don't unpickle (or run!) code you do not know!  This
15     helper is designed to be run with your own code.
16
17 See details in :mod:`pyutils.parallelize.parallelize` for instructions
18 about how to set this up.
19 """
20
21 import logging
22 import os
23 import signal
24 import sys
25 import threading
26 import time
27 from typing import Optional
28
29 import cloudpickle  # type: ignore
30 import psutil  # type: ignore
31
32 from pyutils import argparse_utils, bootstrap, config
33 from pyutils.parallelize.thread_utils import background_thread
34 from pyutils.stopwatch import Timer
35
36 logger = logging.getLogger(__file__)
37
38 cfg = config.add_commandline_args(
39     f"Remote Worker ({__file__})",
40     "Helper to run pickled code remotely and return results",
41 )
42 cfg.add_argument(
43     '--code_file',
44     type=str,
45     required=True,
46     metavar='FILENAME',
47     help='The location of the bundle of code to execute.',
48 )
49 cfg.add_argument(
50     '--result_file',
51     type=str,
52     required=True,
53     metavar='FILENAME',
54     help='The location where we should write the computation results.',
55 )
56 cfg.add_argument(
57     '--watch_for_cancel',
58     action=argparse_utils.ActionNoYes,
59     default=True,
60     help='Should we watch for the cancellation of our parent ssh process?',
61 )
62
63
64 @background_thread
65 def _watch_for_cancel(terminate_event: threading.Event) -> None:
66     logger.debug('Starting up background thread...')
67     p = psutil.Process(os.getpid())
68     while True:
69         saw_sshd = False
70         ancestors = p.parents()
71         for ancestor in ancestors:
72             name = ancestor.name()
73             pid = ancestor.pid
74             logger.debug('Ancestor process %s (pid=%d)', name, pid)
75             if 'ssh' in name.lower():
76                 saw_sshd = True
77                 break
78         if not saw_sshd:
79             logger.error(
80                 'Did not see sshd in our ancestors list?!  Committing suicide.'
81             )
82             os.system('pstree')
83             os.kill(os.getpid(), signal.SIGTERM)
84             time.sleep(5.0)
85             os.kill(os.getpid(), signal.SIGKILL)
86             sys.exit(-1)
87         if terminate_event.is_set():
88             return
89         time.sleep(1.0)
90
91
92 def _cleanup_and_exit(
93     thread: Optional[threading.Thread],
94     stop_event: Optional[threading.Event],
95     exit_code: int,
96 ) -> None:
97     if stop_event is not None:
98         stop_event.set()
99         assert thread is not None
100         thread.join()
101     sys.exit(exit_code)
102
103
104 @bootstrap.initialize
105 def main() -> None:
106     """Remote worker entry point."""
107
108     in_file = config.config['code_file']
109     assert in_file and type(in_file) == str
110     out_file = config.config['result_file']
111     assert out_file and type(out_file) == str
112
113     thread = None
114     stop_event = None
115     if config.config['watch_for_cancel']:
116         thread, stop_event = _watch_for_cancel()
117
118     logger.debug('Reading %s.', in_file)
119     try:
120         with open(in_file, 'rb') as rb:
121             serialized = rb.read()
122     except Exception as e:
123         logger.exception(e)
124         logger.critical('Problem reading %s. Aborting.', in_file)
125         _cleanup_and_exit(thread, stop_event, 1)
126
127     logger.debug('Deserializing %s', in_file)
128     try:
129         fun, args, kwargs = cloudpickle.loads(serialized)
130     except Exception as e:
131         logger.exception(e)
132         logger.critical('Problem deserializing %s. Aborting.', in_file)
133         _cleanup_and_exit(thread, stop_event, 2)
134
135     logger.debug('Invoking user-defined code...')
136     with Timer() as t:
137         ret = fun(*args, **kwargs)
138     logger.debug('User code took %.1fs', t())
139
140     logger.debug('Serializing results')
141     try:
142         serialized = cloudpickle.dumps(ret)
143     except Exception as e:
144         logger.exception(e)
145         logger.critical('Could not serialize result (%s). Aborting.', type(ret))
146         _cleanup_and_exit(thread, stop_event, 3)
147
148     logger.debug('Writing %s', out_file)
149     try:
150         with open(out_file, 'wb') as wb:
151             wb.write(serialized)
152     except Exception as e:
153         logger.exception(e)
154         logger.critical('Error writing %s. Aborting.', out_file)
155         _cleanup_and_exit(thread, stop_event, 4)
156     _cleanup_and_exit(thread, stop_event, 0)
157
158
159 if __name__ == '__main__':
160     main()