X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=remote_worker.py;h=12a5028c30e2bf95093542296bb1cf5a9866f879;hb=5317c50ce7a96a37acfab3800c0935580766dbbf;hp=42aeb854ce633b47e1f1be85b40bce8bf8d436f6;hpb=b454ad295eb3024a238d32bf2aef1ebc3c496b44;p=python_utils.git diff --git a/remote_worker.py b/remote_worker.py index 42aeb85..12a5028 100755 --- a/remote_worker.py +++ b/remote_worker.py @@ -6,11 +6,11 @@ results. import logging import os -import platform import signal import threading import sys import time +from typing import Optional import cloudpickle # type: ignore import psutil # type: ignore @@ -18,6 +18,7 @@ import psutil # type: ignore import argparse_utils import bootstrap import config +from stopwatch import Timer from thread_utils import background_thread @@ -32,20 +33,20 @@ cfg.add_argument( type=str, required=True, metavar='FILENAME', - help='The location of the bundle of code to execute.' + help='The location of the bundle of code to execute.', ) cfg.add_argument( '--result_file', type=str, required=True, metavar='FILENAME', - help='The location where we should write the computation results.' + help='The location where we should write the computation results.', ) cfg.add_argument( '--watch_for_cancel', action=argparse_utils.ActionNoYes, - default=False, - help='Should we watch for the cancellation of our parent ssh process?' + default=True, + help='Should we watch for the cancellation of our parent ssh process?', ) @@ -64,7 +65,9 @@ def watch_for_cancel(terminate_event: threading.Event) -> None: saw_sshd = True break if not saw_sshd: - logger.error('Did not see sshd in our ancestors list?! Committing suicide.') + logger.error( + 'Did not see sshd in our ancestors list?! Committing suicide.' + ) os.system('pstree') os.kill(os.getpid(), signal.SIGTERM) time.sleep(5.0) @@ -75,11 +78,24 @@ def watch_for_cancel(terminate_event: threading.Event) -> None: time.sleep(1.0) +def cleanup_and_exit( + thread: Optional[threading.Thread], + stop_thread: Optional[threading.Event], + exit_code: int, +) -> None: + if stop_thread is not None: + stop_thread.set() + assert thread is not None + thread.join() + sys.exit(exit_code) + + @bootstrap.initialize def main() -> None: in_file = config.config['code_file'] out_file = config.config['result_file'] + thread = None stop_thread = None if config.config['watch_for_cancel']: (thread, stop_thread) = watch_for_cancel() @@ -91,8 +107,7 @@ def main() -> None: except Exception as e: logger.exception(e) logger.critical(f'Problem reading {in_file}. Aborting.') - stop_thread.set() - sys.exit(-1) + cleanup_and_exit(thread, stop_thread, 1) logger.debug(f'Deserializing {in_file}.') try: @@ -100,14 +115,12 @@ def main() -> None: except Exception as e: logger.exception(e) logger.critical(f'Problem deserializing {in_file}. Aborting.') - stop_thread.set() - sys.exit(-1) + cleanup_and_exit(thread, stop_thread, 2) logger.debug('Invoking user code...') - start = time.time() - ret = fun(*args, **kwargs) - end = time.time() - logger.debug(f'User code took {end - start:.1f}s') + with Timer() as t: + ret = fun(*args, **kwargs) + logger.debug(f'User code took {t():.1f}s') logger.debug('Serializing results') try: @@ -115,8 +128,7 @@ def main() -> None: except Exception as e: logger.exception(e) logger.critical(f'Could not serialize result ({type(ret)}). Aborting.') - stop_thread.set() - sys.exit(-1) + cleanup_and_exit(thread, stop_thread, 3) logger.debug(f'Writing {out_file}.') try: @@ -125,12 +137,8 @@ def main() -> None: except Exception as e: logger.exception(e) logger.critical(f'Error writing {out_file}. Aborting.') - stop_thread.set() - sys.exit(-1) - - if stop_thread is not None: - stop_thread.set() - thread.join() + cleanup_and_exit(thread, stop_thread, 4) + cleanup_and_exit(thread, stop_thread, 0) if __name__ == '__main__':