3 # © Copyright 2021-2022, Scott Gasch
5 """A simple utility to unpickle some code, run it, and pickle the
16 from typing import Optional
18 import cloudpickle # type: ignore
19 import psutil # type: ignore
24 from stopwatch import Timer
25 from thread_utils import background_thread
27 logger = logging.getLogger(__file__)
# NOTE(review): the usual logging convention is getLogger(__name__);
# keying the logger by __file__ ties it to the script's path — confirm
# this is intentional before changing.

# Command-line flags for this helper, registered via the project's
# config module (argparse-based).  Several add_argument(...) lines are
# elided from this view; only their help/action fragments are visible.
29 cfg = config.add_commandline_args(
30 f"Remote Worker ({__file__})",
31 "Helper to run pickled code remotely and return results",
38 help='The location of the bundle of code to execute.',
45 help='The location where we should write the computation results.',
49 action=argparse_utils.ActionNoYes,
51 help='Should we watch for the cancellation of our parent ssh process?',
56 def watch_for_cancel(terminate_event: threading.Event) -> None:
    # Background watchdog: checks whether our controlling ssh session is
    # still alive and self-terminates if it is not.  Fragment — the
    # surrounding loop/sleep scaffolding and several statements are elided
    # from this view (it is presumably decorated with @background_thread,
    # per the import above — TODO confirm).
57     logger.debug('Starting up background thread...')
58     p = psutil.Process(os.getpid())
    # Walk our ancestor processes looking for an ssh process by name.
61     ancestors = p.parents()
62     for ancestor in ancestors:
63         name = ancestor.name()
    # NOTE(review): `pid` is referenced here but its assignment (an elided
    # original line, presumably pid = ancestor.pid) is not visible — verify.
65         logger.debug('Ancestor process %s (pid=%d)', name, pid)
66         if 'ssh' in name.lower():
    # No ssh ancestor found: the invoking session is assumed gone, so the
    # worker terminates itself — SIGTERM first (polite), then SIGKILL.
70     logger.error('Did not see sshd in our ancestors list?! Committing suicide.')
72     os.kill(os.getpid(), signal.SIGTERM)
74     os.kill(os.getpid(), signal.SIGKILL)
    # Cooperative shutdown requested by the main thread via the event.
76     if terminate_event.is_set():
# Fragment of a cleanup helper (the `def cleanup_and_exit(...)` line and
# the final exit call are elided from this view).  Visible contract: takes
# the optional watchdog thread and its stop event; presumably signals the
# event, joins the thread, and exits with the given code — TODO confirm.
82     thread: Optional[threading.Thread],
83     stop_thread: Optional[threading.Event],
86     if stop_thread is not None:
    # Invariant: a stop event is only created together with its thread.
88         assert thread is not None
# Fragment of the main entry point: read the pickled code bundle, execute
# it, and write the pickled result.  The enclosing `def` line and every
# `try:` opener for the `except` clauses below are elided from this view.
95     in_file = config.config['code_file']
96     out_file = config.config['result_file']
    # Optionally start the watchdog that kills this worker if the parent
    # ssh session disappears (see watch_for_cancel above).
100     if config.config['watch_for_cancel']:
101         (thread, stop_thread) = watch_for_cancel()
103     logger.debug('Reading %s.', in_file)
105     with open(in_file, 'rb') as rb:
106         serialized = rb.read()
107     except Exception as e:
109         logger.critical('Problem reading %s. Aborting.', in_file)
    # Exit codes 1-4 distinguish the failing stage:
    # 1=read, 2=deserialize, 3=serialize, 4=write result.
110         cleanup_and_exit(thread, stop_thread, 1)
112     logger.debug('Deserializing %s', in_file)
    # NOTE(review): unpickling executes arbitrary code; this is only safe
    # because the bundle comes from the trusted controlling host — confirm.
114         fun, args, kwargs = cloudpickle.loads(serialized)
115     except Exception as e:
117         logger.critical('Problem deserializing %s. Aborting.', in_file)
118         cleanup_and_exit(thread, stop_thread, 2)
120     logger.debug('Invoking user code...')
122         ret = fun(*args, **kwargs)
    # `t` is a stopwatch started on an elided line (presumably
    # `with Timer() as t:` per the import above — TODO confirm).
123     logger.debug('User code took %.1fs', t())
125     logger.debug('Serializing results')
127         serialized = cloudpickle.dumps(ret)
128     except Exception as e:
130         logger.critical('Could not serialize result (%s). Aborting.', type(ret))
131         cleanup_and_exit(thread, stop_thread, 3)
133     logger.debug('Writing %s', out_file)
135     with open(out_file, 'wb') as wb:
137     except Exception as e:
139         logger.critical('Error writing %s. Aborting.', out_file)
140         cleanup_and_exit(thread, stop_thread, 4)
    # Success path: exit code 0 after stopping the watchdog.
141     cleanup_and_exit(thread, stop_thread, 0)
144 if __name__ == '__main__':