#!/usr/bin/env python3

"""A simple utility to unpickle some code, run it, and pickle the results.

The code bundle (--code_file) is a cloudpickled (fun, args, kwargs) tuple;
the return value of fun(*args, **kwargs) is cloudpickled into --result_file.
Optionally (--watch_for_cancel, on by default) a background watchdog kills
this process if it stops being a descendant of an ssh process.
"""

import logging
import os
import signal
import sys
import threading
import time

import cloudpickle  # type: ignore
import psutil  # type: ignore

import argparse_utils
import bootstrap
import config
from thread_utils import background_thread

logger = logging.getLogger(__file__)

cfg = config.add_commandline_args(
    f"Remote Worker ({__file__})",
    "Helper to run pickled code remotely and return results",
)
cfg.add_argument(
    '--code_file',
    type=str,
    required=True,
    metavar='FILENAME',
    help='The location of the bundle of code to execute.',
)
cfg.add_argument(
    '--result_file',
    type=str,
    required=True,
    metavar='FILENAME',
    help='The location where we should write the computation results.',
)
cfg.add_argument(
    '--watch_for_cancel',
    action=argparse_utils.ActionNoYes,
    default=True,
    help='Should we watch for the cancellation of our parent ssh process?',
)


def _has_ssh_ancestor(proc: psutil.Process) -> bool:
    """Return False only if proc's ancestors were read and none looks like ssh.

    An ancestor "looks like ssh" when its process name contains 'ssh'
    (case-insensitive).  Ancestors that exit while being inspected are
    skipped.  If the ancestor list itself cannot be read (a psutil.Error,
    e.g. NoSuchProcess while walking the tree), return True so the caller
    simply re-checks on its next tick instead of acting on a transient race.
    """
    try:
        ancestors = proc.parents()
    except psutil.Error as e:
        logger.debug(f'Could not enumerate ancestor processes ({e}).')
        return True
    for ancestor in ancestors:
        try:
            name = ancestor.name()
        except psutil.Error:
            # This ancestor exited (or became inaccessible) mid-scan; skip it
            # rather than let the exception kill the watchdog thread.
            continue
        pid = ancestor.pid
        logger.debug(f'Ancestor process {name} (pid={pid})')
        if 'ssh' in name.lower():
            return True
    return False


@background_thread
def watch_for_cancel(terminate_event: threading.Event) -> None:
    """Watchdog: kill this process once no ssh process is among its ancestors.

    Re-checks roughly once per second.  When no ssh ancestor is found it logs
    an error, runs `pstree` for diagnostics, sends SIGTERM to this process,
    and escalates to SIGKILL five seconds later.  Returns as soon as
    terminate_event is set.  (Thread creation and the event are supplied by
    the background_thread decorator; main() unpacks its return value as
    (thread, stop_event).)
    """
    logger.debug('Starting up background thread...')
    p = psutil.Process(os.getpid())
    while True:
        if not _has_ssh_ancestor(p):
            logger.error(
                'Did not see sshd in our ancestors list?! Committing suicide.'
            )
            os.system('pstree')
            os.kill(os.getpid(), signal.SIGTERM)
            time.sleep(5.0)
            os.kill(os.getpid(), signal.SIGKILL)
            sys.exit(-1)
        # wait() returns True as soon as the event is set, so a stop request
        # is honored immediately instead of after a full sleep.
        if terminate_event.wait(1.0):
            return


def _run_bundle(in_file: str, out_file: str) -> None:
    """Unpickle (fun, args, kwargs) from in_file, run it, pickle result to out_file.

    Any failure reading, deserializing, serializing, or writing is logged and
    ends the process via sys.exit(-1).  Exceptions raised by the user code
    itself propagate to the caller unchanged.
    """
    logger.debug(f'Reading {in_file}.')
    try:
        with open(in_file, 'rb') as rb:
            serialized = rb.read()
    except Exception as e:
        logger.exception(e)
        logger.critical(f'Problem reading {in_file}. Aborting.')
        sys.exit(-1)

    logger.debug(f'Deserializing {in_file}.')
    try:
        # SECURITY: cloudpickle.loads can execute arbitrary code.  in_file
        # must come from a trusted source (presumably the process that
        # launched this worker); never point it at untrusted data.
        fun, args, kwargs = cloudpickle.loads(serialized)
    except Exception as e:
        logger.exception(e)
        logger.critical(f'Problem deserializing {in_file}. Aborting.')
        sys.exit(-1)

    logger.debug('Invoking user code...')
    start = time.perf_counter()  # monotonic; unaffected by clock changes
    ret = fun(*args, **kwargs)
    end = time.perf_counter()
    logger.debug(f'User code took {end - start:.1f}s')

    logger.debug('Serializing results')
    try:
        serialized = cloudpickle.dumps(ret)
    except Exception as e:
        logger.exception(e)
        logger.critical(f'Could not serialize result ({type(ret)}). Aborting.')
        sys.exit(-1)

    logger.debug(f'Writing {out_file}.')
    try:
        with open(out_file, 'wb') as wb:
            wb.write(serialized)
    except Exception as e:
        logger.exception(e)
        logger.critical(f'Error writing {out_file}. Aborting.')
        sys.exit(-1)


@bootstrap.initialize
def main() -> None:
    """Run the pickled code bundle named by --code_file into --result_file.

    When --watch_for_cancel is enabled, the ssh watchdog runs for the
    duration of the work.  It is always told to stop and is joined before
    main() returns.  This includes the sys.exit(-1) error paths and the case
    where the user code raises.
    """
    in_file = config.config['code_file']
    out_file = config.config['result_file']

    thread = None
    stop_thread = None
    if config.config['watch_for_cancel']:
        (thread, stop_thread) = watch_for_cancel()

    try:
        _run_bundle(in_file, out_file)
    finally:
        # `finally` also runs for SystemExit and for exceptions from user
        # code, so the watchdog never outlives the work.  NOTE(review): if
        # background_thread creates a non-daemon thread, skipping this would
        # keep the process alive after a failure.
        if stop_thread is not None:
            stop_thread.set()
        if thread is not None:
            thread.join()


if __name__ == '__main__':
    main()