Easier and more self documenting patterns for loading/saving Persistent
[python_utils.git] / remote_worker.py
index 43b841589c670b758d52c777e716835b32863c51..8bc254070c7ec030a967efb8938e42639eaa5231 100755 (executable)
@@ -1,15 +1,18 @@
 #!/usr/bin/env python3
 
+# © Copyright 2021-2022, Scott Gasch
+
 """A simple utility to unpickle some code, run it, and pickle the
 results.
 """
 
+import logging
 import os
-import platform
 import signal
 import sys
 import threading
 import time
+from typing import Optional
 
 import cloudpickle  # type: ignore
 import psutil  # type: ignore
@@ -17,8 +20,10 @@ import psutil  # type: ignore
 import argparse_utils
 import bootstrap
 import config
+from stopwatch import Timer
 from thread_utils import background_thread
 
+logger = logging.getLogger(__file__)
 
 cfg = config.add_commandline_args(
     f"Remote Worker ({__file__})",
@@ -29,74 +34,111 @@ cfg.add_argument(
     type=str,
     required=True,
     metavar='FILENAME',
-    help='The location of the bundle of code to execute.'
+    help='The location of the bundle of code to execute.',
 )
 cfg.add_argument(
     '--result_file',
     type=str,
     required=True,
     metavar='FILENAME',
-    help='The location where we should write the computation results.'
+    help='The location where we should write the computation results.',
 )
 cfg.add_argument(
     '--watch_for_cancel',
     action=argparse_utils.ActionNoYes,
-    default=False,
-    help='Should we watch for the cancellation of our parent ssh process?'
+    default=True,
+    help='Should we watch for the cancellation of our parent ssh process?',
 )
 
 
 @background_thread
 def watch_for_cancel(terminate_event: threading.Event) -> None:
+    logger.debug('Starting up background thread...')
     p = psutil.Process(os.getpid())
     while True:
         saw_sshd = False
         ancestors = p.parents()
         for ancestor in ancestors:
             name = ancestor.name()
-            if 'ssh' in name or 'Ssh' in name:
+            pid = ancestor.pid
+            logger.debug('Ancestor process %s (pid=%d)', name, pid)
+            if 'ssh' in name.lower():
                 saw_sshd = True
                 break
         if not saw_sshd:
+            logger.error('Did not see sshd in our ancestors list?!  Committing suicide.')
             os.system('pstree')
             os.kill(os.getpid(), signal.SIGTERM)
+            time.sleep(5.0)
+            os.kill(os.getpid(), signal.SIGKILL)
+            sys.exit(-1)
         if terminate_event.is_set():
             return
         time.sleep(1.0)
 
 
-if __name__ == '__main__':
-    @bootstrap.initialize
-    def main() -> None:
-        hostname = platform.node()
+def cleanup_and_exit(
+    thread: Optional[threading.Thread],
+    stop_thread: Optional[threading.Event],
+    exit_code: int,
+) -> None:
+    if stop_thread is not None:
+        stop_thread.set()
+        assert thread is not None
+        thread.join()
+    sys.exit(exit_code)
 
-        # Windows-Linux is retarded.
-    #    if (
-    #            hostname != 'VIDEO-COMPUTER' and
-    #            config.config['watch_for_cancel']
-    #    ):
-    #        (thread, terminate_event) = watch_for_cancel()
 
-        in_file = config.config['code_file']
-        out_file = config.config['result_file']
+def main() -> None:
+    in_file = config.config['code_file']
+    out_file = config.config['result_file']
 
+    thread = None
+    stop_thread = None
+    if config.config['watch_for_cancel']:
+        (thread, stop_thread) = watch_for_cancel()
+
+    logger.debug('Reading %s.', in_file)
+    try:
         with open(in_file, 'rb') as rb:
             serialized = rb.read()
+    except Exception as e:
+        logger.exception(e)
+        logger.critical('Problem reading %s. Aborting.', in_file)
+        cleanup_and_exit(thread, stop_thread, 1)
 
+    logger.debug('Deserializing %s', in_file)
+    try:
         fun, args, kwargs = cloudpickle.loads(serialized)
-        print(fun)
-        print(args)
-        print(kwargs)
-        print("Invoking the code...")
+    except Exception as e:
+        logger.exception(e)
+        logger.critical('Problem deserializing %s. Aborting.', in_file)
+        cleanup_and_exit(thread, stop_thread, 2)
+
+    logger.debug('Invoking user code...')
+    with Timer() as t:
         ret = fun(*args, **kwargs)
+    logger.debug('User code took %.1fs', t())
 
+    logger.debug('Serializing results')
+    try:
         serialized = cloudpickle.dumps(ret)
+    except Exception as e:
+        logger.exception(e)
+        logger.critical('Could not serialize result (%s). Aborting.', type(ret))
+        cleanup_and_exit(thread, stop_thread, 3)
+
+    logger.debug('Writing %s', out_file)
+    try:
         with open(out_file, 'wb') as wb:
             wb.write(serialized)
+    except Exception as e:
+        logger.exception(e)
+        logger.critical('Error writing %s. Aborting.', out_file)
+        cleanup_and_exit(thread, stop_thread, 4)
+    cleanup_and_exit(thread, stop_thread, 0)
+
 
-        # Windows-Linux is retarded.
-    #    if hostname != 'VIDEO-COMPUTER':
-    #        terminate_event.set()
-    #        thread.join()
-        sys.exit(0)
+if __name__ == '__main__':
     main()