Making remote training work better.
[python_utils.git] / remote_worker.py
index 43b841589c670b758d52c777e716835b32863c51..84f8d56fa33318b507f3f11618c663861718182b 100755 (executable)
@@ -4,11 +4,11 @@
 results.
 """
 
+import logging
 import os
-import platform
 import signal
-import sys
 import threading
+import sys
 import time
 
 import cloudpickle  # type: ignore
@@ -20,6 +20,8 @@ import config
 from thread_utils import background_thread
 
 
+logger = logging.getLogger(__file__)
+
 cfg = config.add_commandline_args(
     f"Remote Worker ({__file__})",
     "Helper to run pickled code remotely and return results",
@@ -54,49 +56,65 @@ def watch_for_cancel(terminate_event: threading.Event) -> None:
         ancestors = p.parents()
         for ancestor in ancestors:
             name = ancestor.name()
-            if 'ssh' in name or 'Ssh' in name:
+            if 'ssh' in name.lower():
                 saw_sshd = True
                 break
         if not saw_sshd:
             os.system('pstree')
             os.kill(os.getpid(), signal.SIGTERM)
+            time.sleep(5.0)
+            os.kill(os.getpid(), signal.SIGKILL)
+            sys.exit(-1)
         if terminate_event.is_set():
             return
         time.sleep(1.0)
 
 
-if __name__ == '__main__':
-    @bootstrap.initialize
-    def main() -> None:
-        hostname = platform.node()
-
-        # Windows-Linux is retarded.
-    #    if (
-    #            hostname != 'VIDEO-COMPUTER' and
-    #            config.config['watch_for_cancel']
-    #    ):
-    #        (thread, terminate_event) = watch_for_cancel()
-
-        in_file = config.config['code_file']
-        out_file = config.config['result_file']
+def main() -> None:
+    in_file = config.config['code_file']
+    out_file = config.config['result_file']
 
+    logger.debug(f'Reading {in_file}.')
+    try:
         with open(in_file, 'rb') as rb:
             serialized = rb.read()
+    except Exception as e:
+        logger.exception(e)
+        logger.critical(f'Problem reading {in_file}.  Aborting.')
+        sys.exit(-1)
 
+    logger.debug(f'Deserializing {in_file}.')
+    try:
         fun, args, kwargs = cloudpickle.loads(serialized)
-        print(fun)
-        print(args)
-        print(kwargs)
-        print("Invoking the code...")
-        ret = fun(*args, **kwargs)
-
+    except Exception as e:
+        logger.exception(e)
+        logger.critical(f'Problem deserializing {in_file}.  Aborting.')
+        sys.exit(-1)
+
+    logger.debug('Invoking user code...')
+    start = time.time()
+    ret = fun(*args, **kwargs)
+    end = time.time()
+    logger.debug(f'User code took {end - start:.1f}s')
+
+    logger.debug('Serializing results')
+    try:
         serialized = cloudpickle.dumps(ret)
+    except Exception as e:
+        logger.exception(e)
+        logger.critical(f'Could not serialize result ({type(ret)}).  Aborting.')
+        sys.exit(-1)
+
+    logger.debug(f'Writing {out_file}.')
+    try:
         with open(out_file, 'wb') as wb:
             wb.write(serialized)
+    except Exception as e:
+        logger.exception(e)
+        logger.critical(f'Error writing {out_file}.  Aborting.')
+        sys.exit(-1)
 
-        # Windows-Linux is retarded.
-    #    if hostname != 'VIDEO-COMPUTER':
-    #        terminate_event.set()
-    #        thread.join()
-        sys.exit(0)
+
+if __name__ == '__main__':
     main()