X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=remote_worker.py;h=75dfe8e46cb3d34e28f44267129ff38020d23a7a;hb=e7822aa364fcc392476ded5537948292f7db2300;hp=ebd510040d15ac377165281a75c20c8ce63a8474;hpb=497fb9e21f45ec08e1486abaee6dfa7b20b8a691;p=python_utils.git diff --git a/remote_worker.py b/remote_worker.py index ebd5100..75dfe8e 100755 --- a/remote_worker.py +++ b/remote_worker.py @@ -4,20 +4,24 @@ results. """ +import logging import os -import platform import signal import sys import threading import time +from typing import Optional import cloudpickle # type: ignore import psutil # type: ignore +import argparse_utils import bootstrap import config +from stopwatch import Timer from thread_utils import background_thread +logger = logging.getLogger(__file__) cfg = config.add_commandline_args( f"Remote Worker ({__file__})", @@ -28,63 +32,110 @@ cfg.add_argument( type=str, required=True, metavar='FILENAME', - help='The location of the bundle of code to execute.' + help='The location of the bundle of code to execute.', ) cfg.add_argument( '--result_file', type=str, required=True, metavar='FILENAME', - help='The location where we should write the computation results.' + help='The location where we should write the computation results.', +) +cfg.add_argument( + '--watch_for_cancel', + action=argparse_utils.ActionNoYes, + default=True, + help='Should we watch for the cancellation of our parent ssh process?', ) @background_thread def watch_for_cancel(terminate_event: threading.Event) -> None: + logger.debug('Starting up background thread...') p = psutil.Process(os.getpid()) while True: saw_sshd = False ancestors = p.parents() for ancestor in ancestors: name = ancestor.name() - if 'ssh' in name or 'Ssh' in name: + pid = ancestor.pid + logger.debug('Ancestor process %s (pid=%d)', name, pid) + if 'ssh' in name.lower(): saw_sshd = True break - if not saw_sshd: + logger.error('Did not see sshd in our ancestors list?! Committing suicide.') os.system('pstree') os.kill(os.getpid(), signal.SIGTERM) + time.sleep(5.0) + os.kill(os.getpid(), signal.SIGKILL) + sys.exit(-1) if terminate_event.is_set(): return time.sleep(1.0) -@bootstrap.initialize -def main() -> None: - hostname = platform.node() +def cleanup_and_exit( + thread: Optional[threading.Thread], + stop_thread: Optional[threading.Event], + exit_code: int, +) -> None: + if stop_thread is not None: + stop_thread.set() + assert thread is not None + thread.join() + sys.exit(exit_code) - # Windows-Linux is retarded. - if hostname != 'VIDEO-COMPUTER': - (thread, terminate_event) = watch_for_cancel() +@bootstrap.initialize +def main() -> None: in_file = config.config['code_file'] out_file = config.config['result_file'] - with open(in_file, 'rb') as rb: - serialized = rb.read() - - fun, args, kwargs = cloudpickle.loads(serialized) - ret = fun(*args, **kwargs) - - serialized = cloudpickle.dumps(ret) - with open(out_file, 'wb') as wb: - wb.write(serialized) - - # Windows-Linux is retarded. - if hostname != 'VIDEO-COMPUTER': - terminate_event.set() - thread.join() - sys.exit(0) + thread = None + stop_thread = None + if config.config['watch_for_cancel']: + (thread, stop_thread) = watch_for_cancel() + + logger.debug('Reading %s.', in_file) + try: + with open(in_file, 'rb') as rb: + serialized = rb.read() + except Exception as e: + logger.exception(e) + logger.critical('Problem reading %s. Aborting.', in_file) + cleanup_and_exit(thread, stop_thread, 1) + + logger.debug('Deserializing %s', in_file) + try: + fun, args, kwargs = cloudpickle.loads(serialized) + except Exception as e: + logger.exception(e) + logger.critical('Problem deserializing %s. Aborting.', in_file) + cleanup_and_exit(thread, stop_thread, 2) + + logger.debug('Invoking user code...') + with Timer() as t: + ret = fun(*args, **kwargs) + logger.debug('User code took %.1fs', t()) + + logger.debug('Serializing results') + try: + serialized = cloudpickle.dumps(ret) + except Exception as e: + logger.exception(e) + logger.critical('Could not serialize result (%s). Aborting.', type(ret)) + cleanup_and_exit(thread, stop_thread, 3) + + logger.debug('Writing %s', out_file) + try: + with open(out_file, 'wb') as wb: + wb.write(serialized) + except Exception as e: + logger.exception(e) + logger.critical('Error writing %s. Aborting.', out_file) + cleanup_and_exit(thread, stop_thread, 4) + cleanup_and_exit(thread, stop_thread, 0) if __name__ == '__main__':