Used isort to sort imports. Also added to the git pre-commit hook.
[python_utils.git] / remote_worker.py
index b58c6ba0a66f8d32b2b81af72a66d23493c9b2e5..82b80ea3d722090ab7254eb24eac5884a9520172 100755 (executable)
@@ -7,9 +7,10 @@ results.
 import logging
 import os
 import signal
-import threading
 import sys
+import threading
 import time
+from typing import Optional
 
 import cloudpickle  # type: ignore
 import psutil  # type: ignore
@@ -17,9 +18,9 @@ import psutil  # type: ignore
 import argparse_utils
 import bootstrap
 import config
+from stopwatch import Timer
 from thread_utils import background_thread
 
-
 logger = logging.getLogger(__file__)
 
 cfg = config.add_commandline_args(
@@ -76,11 +77,24 @@ def watch_for_cancel(terminate_event: threading.Event) -> None:
         time.sleep(1.0)
 
 
+def cleanup_and_exit(
+    thread: Optional[threading.Thread],
+    stop_thread: Optional[threading.Event],
+    exit_code: int,
+) -> None:
+    if stop_thread is not None:
+        stop_thread.set()
+        assert thread is not None
+        thread.join()
+    sys.exit(exit_code)
+
+
 @bootstrap.initialize
 def main() -> None:
     in_file = config.config['code_file']
     out_file = config.config['result_file']
 
+    thread = None
     stop_thread = None
     if config.config['watch_for_cancel']:
         (thread, stop_thread) = watch_for_cancel()
@@ -92,8 +106,7 @@ def main() -> None:
     except Exception as e:
         logger.exception(e)
         logger.critical(f'Problem reading {in_file}.  Aborting.')
-        stop_thread.set()
-        sys.exit(-1)
+        cleanup_and_exit(thread, stop_thread, 1)
 
     logger.debug(f'Deserializing {in_file}.')
     try:
@@ -101,14 +114,12 @@ def main() -> None:
     except Exception as e:
         logger.exception(e)
         logger.critical(f'Problem deserializing {in_file}.  Aborting.')
-        stop_thread.set()
-        sys.exit(-1)
+        cleanup_and_exit(thread, stop_thread, 2)
 
     logger.debug('Invoking user code...')
-    start = time.time()
-    ret = fun(*args, **kwargs)
-    end = time.time()
-    logger.debug(f'User code took {end - start:.1f}s')
+    with Timer() as t:
+        ret = fun(*args, **kwargs)
+    logger.debug(f'User code took {t():.1f}s')
 
     logger.debug('Serializing results')
     try:
@@ -116,8 +127,7 @@ def main() -> None:
     except Exception as e:
         logger.exception(e)
         logger.critical(f'Could not serialize result ({type(ret)}).  Aborting.')
-        stop_thread.set()
-        sys.exit(-1)
+        cleanup_and_exit(thread, stop_thread, 3)
 
     logger.debug(f'Writing {out_file}.')
     try:
@@ -126,12 +136,8 @@ def main() -> None:
     except Exception as e:
         logger.exception(e)
         logger.critical(f'Error writing {out_file}.  Aborting.')
-        stop_thread.set()
-        sys.exit(-1)
-
-    if stop_thread is not None:
-        stop_thread.set()
-        thread.join()
+        cleanup_and_exit(thread, stop_thread, 4)
+    cleanup_and_exit(thread, stop_thread, 0)
 
 
 if __name__ == '__main__':