From bef486c8c06e8d743a98b89910658a615acc8bbc Mon Sep 17 00:00:00 2001
From: Scott Gasch
Date: Mon, 15 Nov 2021 16:55:26 -0800
Subject: [PATCH] Making remote training work better.

---
 executors.py        | 18 ++++++-----
 logging_utils.py    |  2 +-
 ml/model_trainer.py |  2 +-
 parallelize.py      |  4 ---
 remote_worker.py    | 74 ++++++++++++++++++++++++++++-----------------
 5 files changed, 58 insertions(+), 42 deletions(-)

diff --git a/executors.py b/executors.py
index b16ad92..fe8d9d0 100644
--- a/executors.py
+++ b/executors.py
@@ -197,6 +197,11 @@ class ProcessExecutor(BaseExecutor):
         return state


+class RemoteExecutorException(Exception):
+    """Thrown when a bundle cannot be executed despite several retries."""
+    pass
+
+
 @dataclass
 class RemoteWorkerRecord:
     username: str
@@ -508,7 +513,7 @@ class RemoteExecutor(BaseExecutor):
         if self.worker_count <= 0:
             msg = f"We need somewhere to schedule work; count was {self.worker_count}"
             logger.critical(msg)
-            raise Exception(msg)
+            raise RemoteExecutorException(msg)
         self.policy.register_worker_pool(self.workers)
         self.cv = threading.Condition()
         logger.debug(f'Creating {self.worker_count} local threads, one per remote worker.')
@@ -518,9 +523,6 @@
         )
         self.status = RemoteExecutorStatus(self.worker_count)
         self.total_bundles_submitted = 0
-        logger.debug(
-            f'Creating remote processpool with {self.worker_count} remote worker threads.'
-        )

     def is_worker_available(self) -> bool:
         return self.policy.is_worker_available()
@@ -556,7 +558,7 @@ class RemoteExecutor(BaseExecutor):
             # Regular progress report
             self.status.periodic_dump(self.total_bundles_submitted)

-            # Look for bundles to reschedule
+            # Look for bundles to reschedule.
             num_done = len(self.status.finished_bundle_timings)
             if num_done > 7 or (num_done > 5 and self.is_worker_available()):
                 for worker, bundle_uuids in self.status.in_flight_bundles_by_worker.items():
@@ -663,7 +665,7 @@ class RemoteExecutor(BaseExecutor):
             logger.info(f"{uuid}/{fname}: Copying work to {worker} via {cmd}.")
             run_silently(cmd)
             xfer_latency = time.time() - start_ts
-            logger.info(f"{uuid}/{fname}: Copying done in {xfer_latency}s.")
+            logger.info(f"{uuid}/{fname}: Copying done in {xfer_latency:.1f}s.")
         except Exception as e:
             logger.exception(e)
             logger.error(
@@ -969,8 +971,8 @@ class RemoteExecutor(BaseExecutor):
                 f'{uuid}: Tried this bundle too many times already ({retry_limit}x); giving up.'
             )
             if is_original:
-                logger.critical(
-                    f'{uuid}: This is the original of the bundle; results will be incomplete.'
+                raise RemoteExecutorException(
+                    f'{uuid}: This bundle can\'t be completed despite several backups and retries'
                 )
             else:
                 logger.error(f'{uuid}: At least it\'s only a backup; better luck with the others.')
diff --git a/logging_utils.py b/logging_utils.py
index 7be31e3..819e3d3 100644
--- a/logging_utils.py
+++ b/logging_utils.py
@@ -76,7 +76,7 @@ cfg.add_argument(
 cfg.add_argument(
     '--logging_filename_count',
     type=int,
-    default=2,
+    default=7,
     metavar='COUNT',
     help='The number of logging_filename copies to keep before deleting.'
 )
diff --git a/ml/model_trainer.py b/ml/model_trainer.py
index 9435351..f9e132e 100644
--- a/ml/model_trainer.py
+++ b/ml/model_trainer.py
@@ -218,7 +218,7 @@ class TrainingBlueprint(ABC):
                 line = line.strip()
                 try:
                     (key, value) = line.split(self.spec.key_value_delimiter)
-                except Exception as e:
+                except Exception:
                     logger.debug(f"WARNING: bad line in file {filename} '{line}', skipped")
                     continue

diff --git a/parallelize.py b/parallelize.py
index 0822095..d9c202f 100644
--- a/parallelize.py
+++ b/parallelize.py
@@ -6,10 +6,6 @@ from enum import Enum
 import functools
 import typing

-ps_count = 0
-thread_count = 0
-remote_count = 0
-

 class Method(Enum):
     THREAD = 1
diff --git a/remote_worker.py b/remote_worker.py
index 43b8415..84f8d56 100755
--- a/remote_worker.py
+++ b/remote_worker.py
@@ -4,11 +4,11 @@
 results.
 """

+import logging
 import os
-import platform
 import signal
-import sys
 import threading
+import sys
 import time

 import cloudpickle  # type: ignore
@@ -20,6 +20,8 @@ import config
 from thread_utils import background_thread

+logger = logging.getLogger(__file__)
+
 cfg = config.add_commandline_args(
     f"Remote Worker ({__file__})",
     "Helper to run pickled code remotely and return results",
 )
@@ -54,49 +56,65 @@ def watch_for_cancel(terminate_event: threading.Event) -> None:
         ancestors = p.parents()
         for ancestor in ancestors:
             name = ancestor.name()
-            if 'ssh' in name or 'Ssh' in name:
+            if 'ssh' in name.lower():
                 saw_sshd = True
                 break
         if not saw_sshd:
             os.system('pstree')
             os.kill(os.getpid(), signal.SIGTERM)
+            time.sleep(5.0)
+            os.kill(os.getpid(), signal.SIGKILL)
+            sys.exit(-1)
         if terminate_event.is_set():
             return
         time.sleep(1.0)


-if __name__ == '__main__':
-    @bootstrap.initialize
-    def main() -> None:
-        hostname = platform.node()
-
-        # Windows-Linux is retarded.
-        # if (
-        #         hostname != 'VIDEO-COMPUTER' and
-        #         config.config['watch_for_cancel']
-        # ):
-        #     (thread, terminate_event) = watch_for_cancel()
-
-        in_file = config.config['code_file']
-        out_file = config.config['result_file']
+@bootstrap.initialize
+def main() -> None:
+    in_file = config.config['code_file']
+    out_file = config.config['result_file']

+    logger.debug(f'Reading {in_file}.')
+    try:
         with open(in_file, 'rb') as rb:
             serialized = rb.read()
+    except Exception as e:
+        logger.exception(e)
+        logger.critical(f'Problem reading {in_file}. Aborting.')
+        sys.exit(-1)

+    logger.debug(f'Deserializing {in_file}.')
+    try:
         fun, args, kwargs = cloudpickle.loads(serialized)
-        print(fun)
-        print(args)
-        print(kwargs)
-        print("Invoking the code...")
-        ret = fun(*args, **kwargs)
-
+    except Exception as e:
+        logger.exception(e)
+        logger.critical(f'Problem deserializing {in_file}. Aborting.')
+        sys.exit(-1)
+
+    logger.debug('Invoking user code...')
+    start = time.time()
+    ret = fun(*args, **kwargs)
+    end = time.time()
+    logger.debug(f'User code took {end - start:.1f}s')
+
+    logger.debug('Serializing results')
+    try:
         serialized = cloudpickle.dumps(ret)
+    except Exception as e:
+        logger.exception(e)
+        logger.critical(f'Could not serialize result ({type(ret)}). Aborting.')
+        sys.exit(-1)
+
+    logger.debug(f'Writing {out_file}.')
+    try:
         with open(out_file, 'wb') as wb:
             wb.write(serialized)
+    except Exception as e:
+        logger.exception(e)
+        logger.critical(f'Error writing {out_file}. Aborting.')
+        sys.exit(-1)

-        # Windows-Linux is retarded.
-        # if hostname != 'VIDEO-COMPUTER':
-        #     terminate_event.set()
-        #     thread.join()
-        sys.exit(0)
+
+if __name__ == '__main__':
     main()
-- 
2.45.2
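
Note for reviewers (not part of the patch): after this change, a bundle that exhausts its retries and backups surfaces as a RemoteExecutorException rather than only a critical log line, so callers can distinguish a permanently failed original bundle from a recoverable one. Below is a minimal sketch of how calling code might handle it, assuming the exception propagates back to the caller through the future returned by the executor's submit path (the collect_result name and future argument are illustrative, not part of this patch):

    import logging

    import executors  # provides RemoteExecutorException as of this patch

    logger = logging.getLogger(__name__)


    def collect_result(future):
        # If the bundle could not be completed despite several backups and
        # retries, result() re-raises the RemoteExecutorException here.
        try:
            return future.result()
        except executors.RemoteExecutorException as e:
            logger.error('Remote bundle failed permanently: %s', e)
            return None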