#!/usr/bin/env python3

# © Copyright 2021-2023, Scott Gasch

"""A simple utility to unpickle some code from the filesystem, run it,
pickle the results, and save them back to the filesystem.  This file
helps :mod:`pyutils.parallelize.parallelize` and
:mod:`pyutils.parallelize.executors` implement the
:class:`pyutils.parallelize.executors.RemoteExecutor` that distributes
work to different machines when code is marked with the
`@parallelize(method=Method.REMOTE)` decorator.

.. warning::
    Never unpickle (or run!) code from a source you do not trust!  This
    helper is designed to run your own code.

See :mod:`pyutils.parallelize.parallelize` for instructions on how to
set this up.
"""

import logging
import os
import signal
import sys
import threading
import time
from typing import Optional

import cloudpickle  # type: ignore
import psutil  # type: ignore
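# cloudpickle (above) is used instead of the stdlib pickle because it can
# serialize constructs such as lambdas, closures, and interactively defined
# functions; psutil is used below to walk our process ancestry looking for
# an ssh parent.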

from pyutils import argparse_utils, bootstrap, config
from pyutils.parallelize.thread_utils import background_thread
from pyutils.stopwatch import Timer

logger = logging.getLogger(__file__)

cfg = config.add_commandline_args(
    f"Remote Worker ({__file__})",
    "Helper to run pickled code remotely and return results",
)
cfg.add_argument(
    "--code_file",
    type=str,
    required=True,
    metavar="FILENAME",
    help="The location of the bundle of code to execute.",
)
cfg.add_argument(
    "--result_file",
    type=str,
    required=True,
    metavar="FILENAME",
    help="The location where we should write the computation results.",
)
cfg.add_argument(
    "--watch_for_cancel",
    action=argparse_utils.ActionNoYes,
    default=True,
    help="Should we watch for the cancellation of our parent ssh process?",
)

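# Typical invocation, as a sketch (executors.RemoteExecutor normally drives
# this over ssh; the paths shown here are hypothetical):
#
#     python3 remote_worker.py --code_file /tmp/code.bin --result_file /tmp/result.bin
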

@background_thread
def _watch_for_cancel(terminate_event: threading.Event) -> None:
    logger.debug("Starting up background thread...")
    p = psutil.Process(os.getpid())
    while True:
        saw_sshd = False
        ancestors = p.parents()
        for ancestor in ancestors:
            name = ancestor.name()
            pid = ancestor.pid
            logger.debug("Ancestor process %s (pid=%d)", name, pid)
            if "ssh" in name.lower():
                saw_sshd = True
                break
        if not saw_sshd:
            logger.error(
                "Did not find sshd among our ancestor processes; "
                "assuming the ssh connection died.  Exiting."
            )
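            # Escalate: dump the process tree for postmortem debugging, ask
            # politely with SIGTERM, then force the issue with SIGKILL if
            # we are somehow still running five seconds later.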
            os.system("pstree")
            os.kill(os.getpid(), signal.SIGTERM)
            time.sleep(5.0)
            os.kill(os.getpid(), signal.SIGKILL)
            sys.exit(-1)
        if terminate_event.is_set():
            return
        time.sleep(1.0)


def _cleanup_and_exit(
    thread: Optional[threading.Thread],
    stop_event: Optional[threading.Event],
    exit_code: int,
) -> None:
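    """Signal the cancellation watcher (if any) to stop, join it, and exit.

    Exit codes used by main(): 0 = success, 1 = error reading the code
    bundle, 2 = deserialization error, 3 = result serialization error,
    4 = error writing the result file.
    """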
    if stop_event is not None:
        stop_event.set()
        assert thread is not None
        thread.join()
    sys.exit(exit_code)


@bootstrap.initialize
def main() -> None:
    """Remote worker entry point."""

    in_file = config.config["code_file"]
    assert in_file and isinstance(in_file, str)
    out_file = config.config["result_file"]
    assert out_file and isinstance(out_file, str)

    thread = None
    stop_event = None
    if config.config["watch_for_cancel"]:
        thread, stop_event = _watch_for_cancel()
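        # Note: the @background_thread decorator changes the calling
        # convention: the call above starts _watch_for_cancel on its own
        # thread and returns that Thread plus an Event; setting the Event
        # asks the watcher loop to exit (see _cleanup_and_exit).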

    logger.debug("Reading %s.", in_file)
    try:
        with open(in_file, "rb") as rb:
            serialized = rb.read()
    except Exception:
        logger.exception("Problem reading %s; aborting.", in_file)
        _cleanup_and_exit(thread, stop_event, 1)

    logger.debug("Deserializing %s", in_file)
    try:
        fun, args, kwargs = cloudpickle.loads(serialized)
    except Exception:
        logger.exception("Problem deserializing %s. Aborting.", in_file)
        _cleanup_and_exit(thread, stop_event, 2)

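    # Timer (from pyutils.stopwatch) is a context manager; calling the
    # instance after the with-block returns the elapsed seconds.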
    logger.debug("Invoking user-defined code...")
    with Timer() as t:
        ret = fun(*args, **kwargs)
    logger.debug("User code took %.1fs", t())

    logger.debug("Serializing results")
    try:
        serialized = cloudpickle.dumps(ret)
    except Exception:
        logger.exception("Could not serialize result (%s). Aborting.", type(ret))
        _cleanup_and_exit(thread, stop_event, 3)

    logger.debug("Writing %s", out_file)
    try:
        with open(out_file, "wb") as wb:
            wb.write(serialized)
    except Exception:
        logger.exception("Error writing %s. Aborting.", out_file)
        _cleanup_and_exit(thread, stop_event, 4)
    _cleanup_and_exit(thread, stop_event, 0)


if __name__ == "__main__":
    main()
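
# Caller-side protocol, as a sketch mirroring what the reads/writes above
# expect (write_code_bundle and read_result are hypothetical helpers, not
# part of pyutils):
#
#     import cloudpickle
#
#     def write_code_bundle(fun, args, kwargs, code_file: str) -> None:
#         with open(code_file, "wb") as wb:
#             wb.write(cloudpickle.dumps((fun, args, kwargs)))
#
#     def read_result(result_file: str):
#         with open(result_file, "rb") as rb:
#             return cloudpickle.loads(rb.read())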