Making remote training work better.
author Scott Gasch <[email protected]>
Tue, 16 Nov 2021 00:55:26 +0000 (16:55 -0800)
committer Scott Gasch <[email protected]>
Tue, 16 Nov 2021 00:55:26 +0000 (16:55 -0800)
executors.py
logging_utils.py
ml/model_trainer.py
parallelize.py
remote_worker.py

index b16ad92d80a624c466b6d54c5830d5a2f00c8789..fe8d9d0d8e749b0aa85609d04c3444e35b6e89d3 100644 (file)
@@ -197,6 +197,11 @@ class ProcessExecutor(BaseExecutor):
         return state
 
 
+class RemoteExecutorException(Exception):
+    """Thrown when a bundle cannot be executed despite several retries."""
+    pass
+
+
 @dataclass
 class RemoteWorkerRecord:
     username: str
@@ -508,7 +513,7 @@ class RemoteExecutor(BaseExecutor):
         if self.worker_count <= 0:
             msg = f"We need somewhere to schedule work; count was {self.worker_count}"
             logger.critical(msg)
-            raise Exception(msg)
+            raise RemoteExecutorException(msg)
         self.policy.register_worker_pool(self.workers)
         self.cv = threading.Condition()
         logger.debug(f'Creating {self.worker_count} local threads, one per remote worker.')
@@ -518,9 +523,6 @@ class RemoteExecutor(BaseExecutor):
         )
         self.status = RemoteExecutorStatus(self.worker_count)
         self.total_bundles_submitted = 0
-        logger.debug(
-            f'Creating remote processpool with {self.worker_count} remote worker threads.'
-        )
 
     def is_worker_available(self) -> bool:
         return self.policy.is_worker_available()
@@ -556,7 +558,7 @@ class RemoteExecutor(BaseExecutor):
             # Regular progress report
             self.status.periodic_dump(self.total_bundles_submitted)
 
-            # Look for bundles to reschedule
+            # Look for bundles to reschedule.
             num_done = len(self.status.finished_bundle_timings)
             if num_done > 7 or (num_done > 5 and self.is_worker_available()):
                 for worker, bundle_uuids in self.status.in_flight_bundles_by_worker.items():
@@ -663,7 +665,7 @@ class RemoteExecutor(BaseExecutor):
                 logger.info(f"{uuid}/{fname}: Copying work to {worker} via {cmd}.")
                 run_silently(cmd)
                 xfer_latency = time.time() - start_ts
-                logger.info(f"{uuid}/{fname}: Copying done in {xfer_latency}s.")
+                logger.info(f"{uuid}/{fname}: Copying done in {xfer_latency:.1f}s.")
             except Exception as e:
                 logger.exception(e)
                 logger.error(
@@ -969,8 +971,8 @@ class RemoteExecutor(BaseExecutor):
                 f'{uuid}: Tried this bundle too many times already ({retry_limit}x); giving up.'
             )
             if is_original:
-                logger.critical(
-                    f'{uuid}: This is the original of the bundle; results will be incomplete.'
+                raise RemoteExecutorException(
+                    f'{uuid}: This bundle can\'t be completed despite several backups and retries'
                 )
             else:
                 logger.error(f'{uuid}: At least it\'s only a backup; better luck with the others.')
index 7be31e3c7c7d530215538e53922ae2070a70b342..819e3d3ee780a78cc903a890eba03e533b608870 100644 (file)
@@ -76,7 +76,7 @@ cfg.add_argument(
 cfg.add_argument(
     '--logging_filename_count',
     type=int,
-    default=2,
+    default=7,
     metavar='COUNT',
     help='The number of logging_filename copies to keep before deleting.'
 )
index 9435351494824de533f38ad324484f05cf590b03..f9e132e18aa20ecf2461db55257b6037a0c13a4e 100644 (file)
@@ -218,7 +218,7 @@ class TrainingBlueprint(ABC):
                 line = line.strip()
                 try:
                     (key, value) = line.split(self.spec.key_value_delimiter)
-                except Exception as e:
+                except Exception:
                     logger.debug(f"WARNING: bad line in file {filename} '{line}', skipped")
                     continue
 
index 08220951a000e3ee9c24dcc908f54af1067ee265..d9c202faf00d56cf4cfc43e36c821998c249c7c4 100644 (file)
@@ -6,10 +6,6 @@ from enum import Enum
 import functools
 import typing
 
-ps_count = 0
-thread_count = 0
-remote_count = 0
-
 
 class Method(Enum):
     THREAD = 1
index 43b841589c670b758d52c777e716835b32863c51..84f8d56fa33318b507f3f11618c663861718182b 100755 (executable)
@@ -4,11 +4,11 @@
 results.
 """
 
+import logging
 import os
-import platform
 import signal
-import sys
 import threading
+import sys
 import time
 
 import cloudpickle  # type: ignore
@@ -20,6 +20,8 @@ import config
 from thread_utils import background_thread
 
 
+logger = logging.getLogger(__file__)
+
 cfg = config.add_commandline_args(
     f"Remote Worker ({__file__})",
     "Helper to run pickled code remotely and return results",
@@ -54,49 +56,65 @@ def watch_for_cancel(terminate_event: threading.Event) -> None:
         ancestors = p.parents()
         for ancestor in ancestors:
             name = ancestor.name()
-            if 'ssh' in name or 'Ssh' in name:
+            if 'ssh' in name.lower():
                 saw_sshd = True
                 break
         if not saw_sshd:
             os.system('pstree')
             os.kill(os.getpid(), signal.SIGTERM)
+            time.sleep(5.0)
+            os.kill(os.getpid(), signal.SIGKILL)
+            sys.exit(-1)
         if terminate_event.is_set():
             return
         time.sleep(1.0)
 
 
-if __name__ == '__main__':
-    @bootstrap.initialize
-    def main() -> None:
-        hostname = platform.node()
-
-        # Windows-Linux is retarded.
-    #    if (
-    #            hostname != 'VIDEO-COMPUTER' and
-    #            config.config['watch_for_cancel']
-    #    ):
-    #        (thread, terminate_event) = watch_for_cancel()
-
-        in_file = config.config['code_file']
-        out_file = config.config['result_file']
+def main() -> None:
+    in_file = config.config['code_file']
+    out_file = config.config['result_file']
 
+    logger.debug(f'Reading {in_file}.')
+    try:
         with open(in_file, 'rb') as rb:
             serialized = rb.read()
+    except Exception as e:
+        logger.exception(e)
+        logger.critical(f'Problem reading {in_file}.  Aborting.')
+        sys.exit(-1)
 
+    logger.debug(f'Deserializing {in_file}.')
+    try:
         fun, args, kwargs = cloudpickle.loads(serialized)
-        print(fun)
-        print(args)
-        print(kwargs)
-        print("Invoking the code...")
-        ret = fun(*args, **kwargs)
-
+    except Exception as e:
+        logger.exception(e)
+        logger.critical(f'Problem deserializing {in_file}.  Aborting.')
+        sys.exit(-1)
+
+    logger.debug('Invoking user code...')
+    start = time.time()
+    ret = fun(*args, **kwargs)
+    end = time.time()
+    logger.debug(f'User code took {end - start:.1f}s')
+
+    logger.debug('Serializing results')
+    try:
         serialized = cloudpickle.dumps(ret)
+    except Exception as e:
+        logger.exception(e)
+        logger.critical(f'Could not serialize result ({type(ret)}).  Aborting.')
+        sys.exit(-1)
+
+    logger.debug(f'Writing {out_file}.')
+    try:
         with open(out_file, 'wb') as wb:
             wb.write(serialized)
+    except Exception as e:
+        logger.exception(e)
+        logger.critical(f'Error writing {out_file}.  Aborting.')
+        sys.exit(-1)
 
-        # Windows-Linux is retarded.
-    #    if hostname != 'VIDEO-COMPUTER':
-    #        terminate_event.set()
-    #        thread.join()
-        sys.exit(0)
+
+if __name__ == '__main__':
     main()