Update requirements to include psutil.
[python_utils.git] / remote_worker.py
index 0086c40b0379ce680383c3b4e723bdc92b3bec0a..8bc254070c7ec030a967efb8938e42639eaa5231 100755 (executable)
@@ -1,16 +1,18 @@
 #!/usr/bin/env python3
 
+# © Copyright 2021-2022, Scott Gasch
+
 """A simple utility to unpickle some code, run it, and pickle the
 results.
 """
 
 import logging
 import os
-import platform
 import signal
-import threading
 import sys
+import threading
 import time
+from typing import Optional
 
 import cloudpickle  # type: ignore
 import psutil  # type: ignore
@@ -18,9 +20,9 @@ import psutil  # type: ignore
 import argparse_utils
 import bootstrap
 import config
+from stopwatch import Timer
 from thread_utils import background_thread
 
-
 logger = logging.getLogger(__file__)
 
 cfg = config.add_commandline_args(
@@ -32,28 +34,25 @@ cfg.add_argument(
     type=str,
     required=True,
     metavar='FILENAME',
-    help='The location of the bundle of code to execute.'
+    help='The location of the bundle of code to execute.',
 )
 cfg.add_argument(
     '--result_file',
     type=str,
     required=True,
     metavar='FILENAME',
-    help='The location where we should write the computation results.'
+    help='The location where we should write the computation results.',
 )
 cfg.add_argument(
     '--watch_for_cancel',
     action=argparse_utils.ActionNoYes,
-    default=False,
-    help='Should we watch for the cancellation of our parent ssh process?'
+    default=True,
+    help='Should we watch for the cancellation of our parent ssh process?',
 )
 
 
 @background_thread
 def watch_for_cancel(terminate_event: threading.Event) -> None:
-    if platform.node() == 'VIDEO-COMPUTER':
-        logger.warning('Background thread not allowed on retarded computers, sorry.')
-        return
     logger.debug('Starting up background thread...')
     p = psutil.Process(os.getpid())
     while True:
@@ -62,7 +61,7 @@ def watch_for_cancel(terminate_event: threading.Event) -> None:
         for ancestor in ancestors:
             name = ancestor.name()
             pid = ancestor.pid
-            logger.debug(f'Ancestor process {name} (pid={pid})')
+            logger.debug('Ancestor process %s (pid=%d)', name, pid)
             if 'ssh' in name.lower():
                 saw_sshd = True
                 break
@@ -78,62 +77,67 @@ def watch_for_cancel(terminate_event: threading.Event) -> None:
         time.sleep(1.0)
 
 
+def cleanup_and_exit(
+    thread: Optional[threading.Thread],
+    stop_thread: Optional[threading.Event],
+    exit_code: int,
+) -> None:
+    if stop_thread is not None:
+        stop_thread.set()
+        assert thread is not None
+        thread.join()
+    sys.exit(exit_code)
+
+
 @bootstrap.initialize
 def main() -> None:
     in_file = config.config['code_file']
     out_file = config.config['result_file']
 
+    thread = None
     stop_thread = None
     if config.config['watch_for_cancel']:
         (thread, stop_thread) = watch_for_cancel()
 
-    logger.debug(f'Reading {in_file}.')
+    logger.debug('Reading %s.', in_file)
     try:
         with open(in_file, 'rb') as rb:
             serialized = rb.read()
     except Exception as e:
         logger.exception(e)
-        logger.critical(f'Problem reading {in_file}.  Aborting.')
-        stop_thread.set()
-        sys.exit(-1)
+        logger.critical('Problem reading %s. Aborting.', in_file)
+        cleanup_and_exit(thread, stop_thread, 1)
 
-    logger.debug(f'Deserializing {in_file}.')
+    logger.debug('Deserializing %s', in_file)
     try:
         fun, args, kwargs = cloudpickle.loads(serialized)
     except Exception as e:
         logger.exception(e)
-        logger.critical(f'Problem deserializing {in_file}.  Aborting.')
-        stop_thread.set()
-        sys.exit(-1)
+        logger.critical('Problem deserializing %s. Aborting.', in_file)
+        cleanup_and_exit(thread, stop_thread, 2)
 
     logger.debug('Invoking user code...')
-    start = time.time()
-    ret = fun(*args, **kwargs)
-    end = time.time()
-    logger.debug(f'User code took {end - start:.1f}s')
+    with Timer() as t:
+        ret = fun(*args, **kwargs)
+    logger.debug('User code took %.1fs', t())
 
     logger.debug('Serializing results')
     try:
         serialized = cloudpickle.dumps(ret)
     except Exception as e:
         logger.exception(e)
-        logger.critical(f'Could not serialize result ({type(ret)}).  Aborting.')
-        stop_thread.set()
-        sys.exit(-1)
+        logger.critical('Could not serialize result (%s). Aborting.', type(ret))
+        cleanup_and_exit(thread, stop_thread, 3)
 
-    logger.debug(f'Writing {out_file}.')
+    logger.debug('Writing %s', out_file)
     try:
         with open(out_file, 'wb') as wb:
             wb.write(serialized)
     except Exception as e:
         logger.exception(e)
-        logger.critical(f'Error writing {out_file}.  Aborting.')
-        stop_thread.set()
-        sys.exit(-1)
-
-    if stop_thread is not None:
-        stop_thread.set()
-        thread.join()
+        logger.critical('Error writing %s. Aborting.', out_file)
+        cleanup_and_exit(thread, stop_thread, 4)
+    cleanup_and_exit(thread, stop_thread, 0)
 
 
 if __name__ == '__main__':