Various changes.
[python_utils.git] / remote_worker.py
1 #!/usr/bin/env python3
2
3 """A simple utility to unpickle some code, run it, and pickle the
4 results.
5 """
6
7 import os
8 import platform
9 import signal
10 import sys
11 import threading
12 import time
13
14 import cloudpickle  # type: ignore
15 import psutil  # type: ignore
16
17 import argparse_utils
18 import bootstrap
19 import config
20 from thread_utils import background_thread
21
22
# Command-line flags understood by the remote worker.  They are registered
# at module import time via the shared config module.
cfg = config.add_commandline_args(
    f"Remote Worker ({__file__})",
    "Helper to run pickled code remotely and return results",
)

# Input: where to read the bundle of code from.
cfg.add_argument(
    '--code_file',
    metavar='FILENAME',
    required=True,
    type=str,
    help='The location of the bundle of code to execute.',
)

# Output: where to write the computation results.
cfg.add_argument(
    '--result_file',
    metavar='FILENAME',
    required=True,
    type=str,
    help='The location where we should write the computation results.',
)

# Optional yes/no switch (off by default) for the ssh-parent watcher.
cfg.add_argument(
    '--watch_for_cancel',
    default=False,
    action=argparse_utils.ActionNoYes,
    help='Should we watch for the cancellation of our parent ssh process?',
)
47
48
@background_thread
def watch_for_cancel(terminate_event: threading.Event) -> None:
    """Terminate this process once no ssh process is among its ancestors.

    About once a second, scan the names of this process's ancestors.  If
    none of them contains 'ssh' or 'Ssh', print the process tree (via
    ``pstree``) and send SIGTERM to ourselves.

    NOTE(review): presumably the background_thread decorator runs this
    function on its own thread and supplies terminate_event -- confirm in
    thread_utils.

    Args:
        terminate_event: when set, the watcher returns at its next check.
    """
    me = psutil.Process(os.getpid())
    while True:
        saw_sshd = False
        for ancestor in me.parents():
            try:
                name = ancestor.name()
            except psutil.NoSuchProcess:
                # The ancestor exited between parents() and name().  A dead
                # process cannot be our live ssh parent, so skip it rather
                # than letting the exception kill the watcher thread.
                continue
            if 'ssh' in name or 'Ssh' in name:
                saw_sshd = True
                break
        if not saw_sshd:
            # Dump the process tree for debugging before asking to die.
            os.system('pstree')
            os.kill(os.getpid(), signal.SIGTERM)
        # Event.wait() doubles as the one-second poll delay but returns
        # immediately (True) once the event is set, so a caller joining
        # this thread is not held up by a sleep.
        if terminate_event.wait(1.0):
            return
66
67
if __name__ == '__main__':
    @bootstrap.initialize
    def main() -> None:
        """Run the pickled call in --code_file and pickle its result.

        Reads a cloudpickle'd ``(fun, args, kwargs)`` tuple from the file
        named by --code_file, invokes ``fun(*args, **kwargs)``, writes the
        cloudpickle'd return value to the file named by --result_file, and
        exits with status 0.  An exception raised by ``fun`` propagates and
        no result file is written.
        """
        # NOTE: --watch_for_cancel is currently ignored.  The
        # watch_for_cancel() background thread used to be started here (and
        # stopped/joined before exit) on every host except
        # 'VIDEO-COMPUTER', but that code has been disabled -- presumably
        # because it misbehaved under Windows-hosted Linux.
        # TODO: confirm the reason and either restore the watcher or drop
        # the flag.
        in_file = config.config['code_file']
        out_file = config.config['result_file']

        # SECURITY: unpickling executes arbitrary code.  Only point
        # --code_file at bundles produced by a trusted caller.
        with open(in_file, 'rb') as rb:
            serialized = rb.read()

        fun, args, kwargs = cloudpickle.loads(serialized)
        print(fun)
        print(args)
        print(kwargs)
        print("Invoking the code...")
        ret = fun(*args, **kwargs)

        serialized = cloudpickle.dumps(ret)
        with open(out_file, 'wb') as wb:
            wb.write(serialized)

        sys.exit(0)
    main()