Minor cleanup.
authorScott <[email protected]>
Sun, 23 Jan 2022 20:58:13 +0000 (12:58 -0800)
committerScott <[email protected]>
Sun, 23 Jan 2022 20:58:13 +0000 (12:58 -0800)
executors.py
ml/model_trainer.py

index c91c2a6af535502ad31db6bb233fa9145d43dd4a..e092e1058bac0d7ba754aaefbfefcfe51f3ffcec 100644 (file)
@@ -32,15 +32,14 @@ from thread_utils import background_thread
 logger = logging.getLogger(__name__)
 
 parser = config.add_commandline_args(
-    f"Executors ({__file__})",
-    "Args related to processing executors."
+    f"Executors ({__file__})", "Args related to processing executors."
 )
 parser.add_argument(
     '--executors_threadpool_size',
     type=int,
     metavar='#THREADS',
     help='Number of threads in the default threadpool, leave unset for default',
-    default=None
+    default=None,
 )
 parser.add_argument(
     '--executors_processpool_size',
@@ -77,21 +76,15 @@ class BaseExecutor(ABC):
         self.title = title
         self.task_count = 0
         self.histogram = hist.SimpleHistogram(
-            hist.SimpleHistogram.n_evenly_spaced_buckets(
-                int(0), int(500), 50
-            )
+            hist.SimpleHistogram.n_evenly_spaced_buckets(int(0), int(500), 50)
         )
 
     @abstractmethod
-    def submit(self,
-               function: Callable,
-               *args,
-               **kwargs) -> fut.Future:
+    def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
         pass
 
     @abstractmethod
-    def shutdown(self,
-                 wait: bool = True) -> None:
+    def shutdown(self, wait: bool = True) -> None:
         pass
 
     def adjust_task_count(self, delta: int) -> None:
@@ -100,8 +93,7 @@ class BaseExecutor(ABC):
 
 
 class ThreadExecutor(BaseExecutor):
-    def __init__(self,
-                 max_workers: Optional[int] = None):
+    def __init__(self, max_workers: Optional[int] = None):
         super().__init__()
         workers = None
         if max_workers is not None:
@@ -110,8 +102,7 @@ class ThreadExecutor(BaseExecutor):
             workers = config.config['executors_threadpool_size']
         logger.debug(f'Creating threadpool executor with {workers} workers')
         self._thread_pool_executor = fut.ThreadPoolExecutor(
-            max_workers=workers,
-            thread_name_prefix="thread_executor_helper"
+            max_workers=workers, thread_name_prefix="thread_executor_helper"
         )
 
     def run_local_bundle(self, fun, *args, **kwargs):
@@ -126,31 +117,25 @@ class ThreadExecutor(BaseExecutor):
         return result
 
     @overrides
-    def submit(self,
-               function: Callable,
-               *args,
-               **kwargs) -> fut.Future:
+    def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
         self.adjust_task_count(+1)
         newargs = []
         newargs.append(function)
         for arg in args:
             newargs.append(arg)
         return self._thread_pool_executor.submit(
-            self.run_local_bundle,
-            *newargs,
-            **kwargs)
+            self.run_local_bundle, *newargs, **kwargs
+        )
 
     @overrides
-    def shutdown(self,
-                 wait = True) -> None:
+    def shutdown(self, wait=True) -> None:
         logger.debug(f'Shutting down threadpool executor {self.title}')
         print(self.histogram)
         self._thread_pool_executor.shutdown(wait)
 
 
 class ProcessExecutor(BaseExecutor):
-    def __init__(self,
-                 max_workers=None):
+    def __init__(self, max_workers=None):
         super().__init__()
         workers = None
         if max_workers is not None:
@@ -170,22 +155,12 @@ class ProcessExecutor(BaseExecutor):
         return result
 
     @overrides
-    def submit(self,
-               function: Callable,
-               *args,
-               **kwargs) -> fut.Future:
+    def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
         start = time.time()
         self.adjust_task_count(+1)
         pickle = make_cloud_pickle(function, *args, **kwargs)
-        result = self._process_executor.submit(
-            self.run_cloud_pickle,
-            pickle
-        )
-        result.add_done_callback(
-            lambda _: self.histogram.add_item(
-                time.time() - start
-            )
-        )
+        result = self._process_executor.submit(self.run_cloud_pickle, pickle)
+        result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
         return result
 
     @overrides
@@ -202,6 +177,7 @@ class ProcessExecutor(BaseExecutor):
 
 class RemoteExecutorException(Exception):
     """Thrown when a bundle cannot be executed despite several retries."""
+
     pass
 
 
@@ -277,13 +253,9 @@ class RemoteExecutorStatus:
         self.start_per_bundle: Dict[str, float] = defaultdict(float)
         self.end_per_bundle: Dict[str, float] = defaultdict(float)
         self.finished_bundle_timings_per_worker: Dict[
-            RemoteWorkerRecord,
-            List[float]
-        ] = {}
-        self.in_flight_bundles_by_worker: Dict[
-            RemoteWorkerRecord,
-            Set[str]
+            RemoteWorkerRecord, List[float]
         ] = {}
+        self.in_flight_bundles_by_worker: Dict[RemoteWorkerRecord, Set[str]] = {}
         self.bundle_details_by_uuid: Dict[str, BundleDetails] = {}
         self.finished_bundle_timings: List[float] = []
         self.last_periodic_dump: Optional[float] = None
@@ -293,21 +265,12 @@ class RemoteExecutorStatus:
         # as a memory fence for modifications to bundle.
         self.lock = threading.Lock()
 
-    def record_acquire_worker(
-            self,
-            worker: RemoteWorkerRecord,
-            uuid: str
-    ) -> None:
+    def record_acquire_worker(self, worker: RemoteWorkerRecord, uuid: str) -> None:
         with self.lock:
-            self.record_acquire_worker_already_locked(
-                worker,
-                uuid
-            )
+            self.record_acquire_worker_already_locked(worker, uuid)
 
     def record_acquire_worker_already_locked(
-            self,
-            worker: RemoteWorkerRecord,
-            uuid: str
+        self, worker: RemoteWorkerRecord, uuid: str
     ) -> None:
         assert self.lock.locked()
         self.known_workers.add(worker)
@@ -316,34 +279,28 @@ class RemoteExecutorStatus:
         x.add(uuid)
         self.in_flight_bundles_by_worker[worker] = x
 
-    def record_bundle_details(
-            self,
-            details: BundleDetails) -> None:
+    def record_bundle_details(self, details: BundleDetails) -> None:
         with self.lock:
             self.record_bundle_details_already_locked(details)
 
-    def record_bundle_details_already_locked(
-            self,
-            details: BundleDetails) -> None:
+    def record_bundle_details_already_locked(self, details: BundleDetails) -> None:
         assert self.lock.locked()
         self.bundle_details_by_uuid[details.uuid] = details
 
     def record_release_worker(
-            self,
-            worker: RemoteWorkerRecord,
-            uuid: str,
-            was_cancelled: bool,
+        self,
+        worker: RemoteWorkerRecord,
+        uuid: str,
+        was_cancelled: bool,
     ) -> None:
         with self.lock:
-            self.record_release_worker_already_locked(
-                worker, uuid, was_cancelled
-            )
+            self.record_release_worker_already_locked(worker, uuid, was_cancelled)
 
     def record_release_worker_already_locked(
-            self,
-            worker: RemoteWorkerRecord,
-            uuid: str,
-            was_cancelled: bool,
+        self,
+        worker: RemoteWorkerRecord,
+        uuid: str,
+        was_cancelled: bool,
     ) -> None:
         assert self.lock.locked()
         ts = time.time()
@@ -407,10 +364,7 @@ class RemoteExecutorStatus:
             if in_flight > 0:
                 ret += f'    ...{in_flight} bundles currently in flight:\n'
                 for bundle_uuid in self.in_flight_bundles_by_worker[worker]:
-                    details = self.bundle_details_by_uuid.get(
-                        bundle_uuid,
-                        None
-                    )
+                    details = self.bundle_details_by_uuid.get(bundle_uuid, None)
                     pid = str(details.pid) if (details and details.pid != 0) else "TBD"
                     if self.start_per_bundle[bundle_uuid] is not None:
                         sec = ts - self.start_per_bundle[bundle_uuid]
@@ -442,10 +396,7 @@ class RemoteExecutorStatus:
         assert self.lock.locked()
         self.total_bundles_submitted = total_bundles_submitted
         ts = time.time()
-        if (
-                self.last_periodic_dump is None
-                or ts - self.last_periodic_dump > 5.0
-        ):
+        if self.last_periodic_dump is None or ts - self.last_periodic_dump > 5.0:
             print(self)
             self.last_periodic_dump = ts
 
@@ -459,10 +410,7 @@ class RemoteWorkerSelectionPolicy(ABC):
         pass
 
     @abstractmethod
-    def acquire_worker(
-            self,
-            machine_to_avoid = None
-    ) -> Optional[RemoteWorkerRecord]:
+    def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
         pass
 
 
@@ -475,10 +423,7 @@ class WeightedRandomRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
         return False
 
     @overrides
-    def acquire_worker(
-            self,
-            machine_to_avoid = None
-    ) -> Optional[RemoteWorkerRecord]:
+    def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
         grabbag = []
         for worker in self.workers:
             for x in range(0, worker.count):
@@ -511,8 +456,7 @@ class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
 
     @overrides
     def acquire_worker(
-            self,
-            machine_to_avoid: str = None
+        self, machine_to_avoid: str = None
     ) -> Optional[RemoteWorkerRecord]:
         x = self.index
         while True:
@@ -535,9 +479,9 @@ class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
 
 
 class RemoteExecutor(BaseExecutor):
-    def __init__(self,
-                 workers: List[RemoteWorkerRecord],
-                 policy: RemoteWorkerSelectionPolicy) -> None:
+    def __init__(
+        self, workers: List[RemoteWorkerRecord], policy: RemoteWorkerSelectionPolicy
+    ) -> None:
         super().__init__()
         self.workers = workers
         self.policy = policy
@@ -550,7 +494,9 @@ class RemoteExecutor(BaseExecutor):
             raise RemoteExecutorException(msg)
         self.policy.register_worker_pool(self.workers)
         self.cv = threading.Condition()
-        logger.debug(f'Creating {self.worker_count} local threads, one per remote worker.')
+        logger.debug(
+            f'Creating {self.worker_count} local threads, one per remote worker.'
+        )
         self._helper_executor = fut.ThreadPoolExecutor(
             thread_name_prefix="remote_executor_helper",
             max_workers=self.worker_count,
@@ -559,7 +505,10 @@ class RemoteExecutor(BaseExecutor):
         self.total_bundles_submitted = 0
         self.backup_lock = threading.Lock()
         self.last_backup = None
-        (self.heartbeat_thread, self.heartbeat_stop_event) = self.run_periodic_heartbeat()
+        (
+            self.heartbeat_thread,
+            self.heartbeat_stop_event,
+        ) = self.run_periodic_heartbeat()
 
     @background_thread
     def run_periodic_heartbeat(self, stop_event) -> None:
@@ -584,17 +533,20 @@ class RemoteExecutor(BaseExecutor):
         num_idle_workers = self.worker_count - self.task_count
         now = time.time()
         if (
-                num_done > 2
-                and num_idle_workers > 1
-                and (self.last_backup is None or (now - self.last_backup > 6.0))
-                and self.backup_lock.acquire(blocking=False)
+            num_done > 2
+            and num_idle_workers > 1
+            and (self.last_backup is None or (now - self.last_backup > 6.0))
+            and self.backup_lock.acquire(blocking=False)
         ):
             try:
                 assert self.backup_lock.locked()
 
                 bundle_to_backup = None
                 best_score = None
-                for worker, bundle_uuids in self.status.in_flight_bundles_by_worker.items():
+                for (
+                    worker,
+                    bundle_uuids,
+                ) in self.status.in_flight_bundles_by_worker.items():
 
                     # Prefer to schedule backups of bundles running on
                     # slower machines.
@@ -610,9 +562,9 @@ class RemoteExecutor(BaseExecutor):
                     for uuid in bundle_uuids:
                         bundle = self.status.bundle_details_by_uuid.get(uuid, None)
                         if (
-                                bundle is not None
-                                and bundle.src_bundle is None
-                                and bundle.backup_bundles is not None
+                            bundle is not None
+                            and bundle.src_bundle is None
+                            and bundle.backup_bundles is not None
                         ):
                             score = base_score
 
@@ -623,15 +575,21 @@ class RemoteExecutor(BaseExecutor):
                             if start_ts is not None:
                                 runtime = now - start_ts
                                 score += runtime
-                                logger.debug(f'score[{bundle}] => {score}  # latency boost')
+                                logger.debug(
+                                    f'score[{bundle}] => {score}  # latency boost'
+                                )
 
                                 if bundle.slower_than_local_p95:
                                     score += runtime / 2
-                                    logger.debug(f'score[{bundle}] => {score}  # >worker p95')
+                                    logger.debug(
+                                        f'score[{bundle}] => {score}  # >worker p95'
+                                    )
 
                                 if bundle.slower_than_global_p95:
                                     score += runtime / 4
-                                    logger.debug(f'score[{bundle}] => {score}  # >global p95')
+                                    logger.debug(
+                                        f'score[{bundle}] => {score}  # >global p95'
+                                    )
 
                             # Prefer backups of bundles that don't
                             # have backups already.
@@ -648,9 +606,8 @@ class RemoteExecutor(BaseExecutor):
                                 f'score[{bundle}] => {score}  # {backup_count} dup backup factor'
                             )
 
-                            if (
-                                    score != 0
-                                    and (best_score is None or score > best_score)
+                            if score != 0 and (
+                                best_score is None or score > best_score
                             ):
                                 bundle_to_backup = bundle
                                 assert bundle is not None
@@ -677,14 +634,12 @@ class RemoteExecutor(BaseExecutor):
         return self.policy.is_worker_available()
 
     def acquire_worker(
-            self,
-            machine_to_avoid: str = None
+        self, machine_to_avoid: str = None
     ) -> Optional[RemoteWorkerRecord]:
         return self.policy.acquire_worker(machine_to_avoid)
 
     def find_available_worker_or_block(
-            self,
-            machine_to_avoid: str = None
+        self, machine_to_avoid: str = None
     ) -> RemoteWorkerRecord:
         with self.cv:
             while not self.is_worker_available():
@@ -696,12 +651,7 @@ class RemoteExecutor(BaseExecutor):
         logger.critical(msg)
         raise Exception(msg)
 
-    def release_worker(
-            self,
-            bundle: BundleDetails,
-            *,
-            was_cancelled=True
-    ) -> None:
+    def release_worker(self, bundle: BundleDetails, *, was_cancelled=True) -> None:
         worker = bundle.worker
         assert worker is not None
         logger.debug(f'Released worker {worker}')
@@ -768,8 +718,8 @@ class RemoteExecutor(BaseExecutor):
                     # thing.
                     logger.exception(e)
                     logger.error(
-                        f'{bundle}: We are the original owner thread and yet there are ' +
-                        'no results for this bundle.  This is unexpected and bad.'
+                        f'{bundle}: We are the original owner thread and yet there are '
+                        'no results for this bundle.  This is unexpected and bad.'
                     )
                     return self.emergency_retry_nasty_bundle(bundle)
                 else:
@@ -785,12 +735,14 @@ class RemoteExecutor(BaseExecutor):
         # Send input code / data to worker machine if it's not local.
         if hostname not in machine:
             try:
-                cmd = f'{SCP} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
+                cmd = (
+                    f'{SCP} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
+                )
                 start_ts = time.time()
                 logger.info(f"{bundle}: Copying work to {worker} via {cmd}.")
                 run_silently(cmd)
                 xfer_latency = time.time() - start_ts
-                logger.info(f"{bundle}: Copying to {worker} took {xfer_latency:.1f}s.")
+                logger.debug(f"{bundle}: Copying to {worker} took {xfer_latency:.1f}s.")
             except Exception as e:
                 self.release_worker(bundle)
                 if is_original:
@@ -798,9 +750,9 @@ class RemoteExecutor(BaseExecutor):
                     # And we're the original bundle.  We have to retry.
                     logger.exception(e)
                     logger.error(
-                        f"{bundle}: Failed to send instructions to the worker machine?! " +
-                        "This is not expected; we\'re the original bundle so this shouldn\'t " +
-                        "be a race condition.  Attempting an emergency retry..."
+                        f"{bundle}: Failed to send instructions to the worker machine?! "
+                        + "This is not expected; we\'re the original bundle so this shouldn\'t "
+                        "be a race condition.  Attempting an emergency retry..."
                     )
                     return self.emergency_retry_nasty_bundle(bundle)
                 else:
@@ -817,17 +769,23 @@ class RemoteExecutor(BaseExecutor):
         # Kick off the work.  Note that if this fails we let
         # wait_for_process deal with it.
         self.status.record_processing_began(uuid)
-        cmd = (f'{SSH} {bundle.username}@{bundle.machine} '
-               f'"source py38-venv/bin/activate &&'
-               f' /home/scott/lib/python_modules/remote_worker.py'
-               f' --code_file {bundle.code_file} --result_file {bundle.result_file}"')
+        cmd = (
+            f'{SSH} {bundle.username}@{bundle.machine} '
+            f'"source py38-venv/bin/activate &&'
+            f' /home/scott/lib/python_modules/remote_worker.py'
+            f' --code_file {bundle.code_file} --result_file {bundle.result_file}"'
+        )
         logger.debug(f'{bundle}: Executing {cmd} in the background to kick off work...')
         p = cmd_in_background(cmd, silent=True)
         bundle.pid = p.pid
-        logger.debug(f'{bundle}: Local ssh process pid={p.pid}; remote worker is {machine}.')
+        logger.debug(
+            f'{bundle}: Local ssh process pid={p.pid}; remote worker is {machine}.'
+        )
         return self.wait_for_process(p, bundle, 0)
 
-    def wait_for_process(self, p: subprocess.Popen, bundle: BundleDetails, depth: int) -> Any:
+    def wait_for_process(
+        self, p: subprocess.Popen, bundle: BundleDetails, depth: int
+    ) -> Any:
         machine = bundle.machine
         pid = p.pid
         if depth > 3:
@@ -847,13 +805,11 @@ class RemoteExecutor(BaseExecutor):
             except subprocess.TimeoutExpired:
                 if self.check_if_cancelled(bundle):
                     logger.info(
-                        f'{bundle}: another worker finished bundle, checking it out...'
+                        f'{bundle}: looks like another worker finished bundle...'
                     )
                     break
             else:
-                logger.info(
-                    f"{bundle}: pid {pid} ({machine}) is finished!"
-                )
+                logger.info(f"{bundle}: pid {pid} ({machine}) is finished!")
                 p = None
                 break
 
@@ -913,13 +869,17 @@ class RemoteExecutor(BaseExecutor):
                         except Exception as e:
                             attempts += 1
                             if attempts >= 3:
-                                raise(e)
+                                raise (e)
                         else:
                             break
 
-                    run_silently(f'{SSH} {username}@{machine}'
-                                 f' "/bin/rm -f {code_file} {result_file}"')
-                    logger.debug(f'Fetching results back took {time.time() - bundle.end_ts:.1f}s.')
+                    run_silently(
+                        f'{SSH} {username}@{machine}'
+                        f' "/bin/rm -f {code_file} {result_file}"'
+                    )
+                    logger.debug(
+                        f'Fetching results back took {time.time() - bundle.end_ts:.1f}s.'
+                    )
                 dur = bundle.end_ts - bundle.start_ts
                 self.histogram.add_item(dur)
 
@@ -944,9 +904,7 @@ class RemoteExecutor(BaseExecutor):
                 # Re-raise the exception; the code in wait_for_process may
                 # decide to emergency_retry_nasty_bundle here.
                 raise Exception(e)
-            logger.debug(
-                f'Removing local (master) {code_file} and {result_file}.'
-            )
+            logger.debug(f'Removing local (master) {code_file} and {result_file}.')
             os.remove(f'{result_file}')
             os.remove(f'{code_file}')
 
@@ -980,6 +938,7 @@ class RemoteExecutor(BaseExecutor):
 
     def create_original_bundle(self, pickle, fname: str):
         from string_utils import generate_uuid
+
         uuid = generate_uuid(omit_dashes=True)
         code_file = f'/tmp/{uuid}.code.bin'
         result_file = f'/tmp/{uuid}.result.bin'
@@ -989,25 +948,25 @@ class RemoteExecutor(BaseExecutor):
             wb.write(pickle)
 
         bundle = BundleDetails(
-            pickled_code = pickle,
-            uuid = uuid,
-            fname = fname,
-            worker = None,
-            username = None,
-            machine = None,
-            hostname = platform.node(),
-            code_file = code_file,
-            result_file = result_file,
-            pid = 0,
-            start_ts = time.time(),
-            end_ts = 0.0,
-            slower_than_local_p95 = False,
-            slower_than_global_p95 = False,
-            src_bundle = None,
-            is_cancelled = threading.Event(),
-            was_cancelled = False,
-            backup_bundles = [],
-            failure_count = 0,
+            pickled_code=pickle,
+            uuid=uuid,
+            fname=fname,
+            worker=None,
+            username=None,
+            machine=None,
+            hostname=platform.node(),
+            code_file=code_file,
+            result_file=result_file,
+            pid=0,
+            start_ts=time.time(),
+            end_ts=0.0,
+            slower_than_local_p95=False,
+            slower_than_global_p95=False,
+            src_bundle=None,
+            is_cancelled=threading.Event(),
+            was_cancelled=False,
+            backup_bundles=[],
+            failure_count=0,
         )
         self.status.record_bundle_details(bundle)
         logger.debug(f'{bundle}: Created an original bundle')
@@ -1019,33 +978,32 @@ class RemoteExecutor(BaseExecutor):
         uuid = src_bundle.uuid + f'_backup#{n}'
 
         backup_bundle = BundleDetails(
-            pickled_code = src_bundle.pickled_code,
-            uuid = uuid,
-            fname = src_bundle.fname,
-            worker = None,
-            username = None,
-            machine = None,
-            hostname = src_bundle.hostname,
-            code_file = src_bundle.code_file,
-            result_file = src_bundle.result_file,
-            pid = 0,
-            start_ts = time.time(),
-            end_ts = 0.0,
-            slower_than_local_p95 = False,
-            slower_than_global_p95 = False,
-            src_bundle = src_bundle,
-            is_cancelled = threading.Event(),
-            was_cancelled = False,
-            backup_bundles = None,    # backup backups not allowed
-            failure_count = 0,
+            pickled_code=src_bundle.pickled_code,
+            uuid=uuid,
+            fname=src_bundle.fname,
+            worker=None,
+            username=None,
+            machine=None,
+            hostname=src_bundle.hostname,
+            code_file=src_bundle.code_file,
+            result_file=src_bundle.result_file,
+            pid=0,
+            start_ts=time.time(),
+            end_ts=0.0,
+            slower_than_local_p95=False,
+            slower_than_global_p95=False,
+            src_bundle=src_bundle,
+            is_cancelled=threading.Event(),
+            was_cancelled=False,
+            backup_bundles=None,  # backup backups not allowed
+            failure_count=0,
         )
         src_bundle.backup_bundles.append(backup_bundle)
         self.status.record_bundle_details_already_locked(backup_bundle)
         logger.debug(f'{backup_bundle}: Created a backup bundle')
         return backup_bundle
 
-    def schedule_backup_for_bundle(self,
-                                   src_bundle: BundleDetails):
+    def schedule_backup_for_bundle(self, src_bundle: BundleDetails):
         assert self.status.lock.locked()
         assert src_bundle is not None
         backup_bundle = self.create_backup_bundle(src_bundle)
@@ -1079,7 +1037,9 @@ class RemoteExecutor(BaseExecutor):
                     f'{bundle}: This bundle can\'t be completed despite several backups and retries'
                 )
             else:
-                logger.error(f'{bundle}: At least it\'s only a backup; better luck with the others.')
+                logger.error(
+                    f'{bundle}: At least it\'s only a backup; better luck with the others.'
+                )
             return None
         else:
             msg = f'>>> Emergency rescheduling {bundle} because of unexected errors (wtf?!) <<<'
@@ -1088,10 +1048,7 @@ class RemoteExecutor(BaseExecutor):
             return self.launch(bundle, avoid_last_machine)
 
     @overrides
-    def submit(self,
-               function: Callable,
-               *args,
-               **kwargs) -> fut.Future:
+    def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
         pickle = make_cloud_pickle(function, *args, **kwargs)
         bundle = self.create_original_bundle(pickle, function.__name__)
         self.total_bundles_submitted += 1
@@ -1117,8 +1074,7 @@ class DefaultExecutors(object):
         logger.debug(f'RUN> ping -c 1 {host}')
         try:
             x = cmd_with_timeout(
-                f'ping -c 1 {host} >/dev/null 2>/dev/null',
-                timeout_seconds=1.0
+                f'ping -c 1 {host} >/dev/null 2>/dev/null', timeout_seconds=1.0
             )
             return x == 0
         except Exception:
@@ -1142,50 +1098,50 @@ class DefaultExecutors(object):
                 logger.info('Found cheetah.house')
                 pool.append(
                     RemoteWorkerRecord(
-                        username = 'scott',
-                        machine = 'cheetah.house',
-                        weight = 25,
-                        count = 6,
+                        username='scott',
+                        machine='cheetah.house',
+                        weight=25,
+                        count=6,
                     ),
                 )
             if self.ping('meerkat.cabin'):
                 logger.info('Found meerkat.cabin')
                 pool.append(
                     RemoteWorkerRecord(
-                        username = 'scott',
-                        machine = 'meerkat.cabin',
-                        weight = 12,
-                        count = 2,
+                        username='scott',
+                        machine='meerkat.cabin',
+                        weight=12,
+                        count=2,
                     ),
                 )
             if self.ping('wannabe.house'):
                 logger.info('Found wannabe.house')
                 pool.append(
                     RemoteWorkerRecord(
-                        username = 'scott',
-                        machine = 'wannabe.house',
-                        weight = 30,
-                        count = 10,
+                        username='scott',
+                        machine='wannabe.house',
+                        weight=30,
+                        count=10,
                     ),
                 )
             if self.ping('puma.cabin'):
                 logger.info('Found puma.cabin')
                 pool.append(
                     RemoteWorkerRecord(
-                        username = 'scott',
-                        machine = 'puma.cabin',
-                        weight = 25,
-                        count = 6,
+                        username='scott',
+                        machine='puma.cabin',
+                        weight=25,
+                        count=6,
                     ),
                 )
             if self.ping('backup.house'):
                 logger.info('Found backup.house')
                 pool.append(
                     RemoteWorkerRecord(
-                        username = 'scott',
-                        machine = 'backup.house',
-                        weight = 7,
-                        count = 2,
+                        username='scott',
+                        machine='backup.house',
+                        weight=7,
+                        count=2,
                     ),
                 )
 
index 041f0f805cc5958fb948a48cc9bc160bc8578956..213a1814cff5e98507e30c19e17669ab123886ce 100644 (file)
@@ -28,18 +28,17 @@ import parallelize as par
 logger = logging.getLogger(__file__)
 
 parser = config.add_commandline_args(
-    f"ML Model Trainer ({__file__})",
-    "Arguments related to training an ML model"
+    f"ML Model Trainer ({__file__})", "Arguments related to training an ML model"
 )
 parser.add_argument(
     "--ml_trainer_quiet",
     action="store_true",
-    help="Don't prompt the user for anything."
+    help="Don't prompt the user for anything.",
 )
 parser.add_argument(
     "--ml_trainer_delete",
     action="store_true",
-    help="Delete invalid/incomplete features files in addition to warning."
+    help="Delete invalid/incomplete features files in addition to warning.",
 )
 group = parser.add_mutually_exclusive_group()
 group.add_argument(
@@ -71,10 +70,10 @@ class InputSpec(SimpleNamespace):
     @staticmethod
     def populate_from_config() -> InputSpec:
         return InputSpec(
-            dry_run = config.config["ml_trainer_dry_run"],
-            quiet = config.config["ml_trainer_quiet"],
-            persist_percentage_threshold = config.config["ml_trainer_persist_threshold"],
-            delete_bad_inputs = config.config["ml_trainer_delete"],
+            dry_run=config.config["ml_trainer_dry_run"],
+            quiet=config.config["ml_trainer_quiet"],
+            persist_percentage_threshold=config.config["ml_trainer_persist_threshold"],
+            delete_bad_inputs=config.config["ml_trainer_delete"],
         )
 
 
@@ -127,11 +126,7 @@ class TrainingBlueprint(ABC):
         models = []
         modelid_to_params = {}
         for params in self.spec.training_parameters:
-            model = self.train_model(
-                params,
-                self.X_train_scaled,
-                self.y_train
-            )
+            model = self.train_model(params, self.X_train_scaled, self.y_train)
             models.append(model)
             modelid_to_params[model.get_id()] = str(params)
 
@@ -167,9 +162,7 @@ class TrainingBlueprint(ABC):
                     best_model = model
                     best_params = params
                     if not self.spec.quiet:
-                        print(
-                            f"New best score {best_score:.2f}% with params {params}"
-                        )
+                        print(f"New best score {best_score:.2f}% with params {params}")
 
         if not self.spec.quiet:
             executors.DefaultExecutors().shutdown()
@@ -177,30 +170,28 @@ class TrainingBlueprint(ABC):
             print(msg)
             logger.info(msg)
 
-        scaler_filename, model_filename, model_info_filename = (
-            self.maybe_persist_scaler_and_model(
-                best_training_score,
-                best_test_score,
-                best_params,
-                num_examples,
-                scaler,
-                best_model,
-            )
+        (
+            scaler_filename,
+            model_filename,
+            model_info_filename,
+        ) = self.maybe_persist_scaler_and_model(
+            best_training_score,
+            best_test_score,
+            best_params,
+            num_examples,
+            scaler,
+            best_model,
         )
         return OutputSpec(
-            model_filename = model_filename,
-            model_info_filename = model_info_filename,
-            scaler_filename = scaler_filename,
-            training_score = best_training_score,
-            test_score = best_test_score,
+            model_filename=model_filename,
+            model_info_filename=model_info_filename,
+            scaler_filename=scaler_filename,
+            training_score=best_training_score,
+            test_score=best_test_score,
         )
 
     @par.parallelize(method=par.Method.THREAD)
-    def read_files_from_list(
-            self,
-            files: List[str],
-            n: int
-    ) -> Tuple[List, List]:
+    def read_files_from_list(self, files: List[str], n: int) -> Tuple[List, List]:
         # All features
         X = []
 
@@ -223,13 +214,17 @@ class TrainingBlueprint(ABC):
                 try:
                     (key, value) = line.split(self.spec.key_value_delimiter)
                 except Exception:
-                    logger.debug(f"WARNING: bad line in file {filename} '{line}', skipped")
+                    logger.debug(
+                        f"WARNING: bad line in file {filename} '{line}', skipped"
+                    )
                     continue
 
                 key = key.strip()
                 value = value.strip()
-                if (self.spec.features_to_skip is not None
-                        and key in self.spec.features_to_skip):
+                if (
+                    self.spec.features_to_skip is not None
+                    and key in self.spec.features_to_skip
+                ):
                     logger.debug(f"Skipping feature {key}")
                     continue
 
@@ -263,10 +258,8 @@ class TrainingBlueprint(ABC):
     def make_progress_graph(self) -> None:
         if not self.spec.quiet:
             from text_utils import progress_graph
-            progress_graph(
-                self.file_done_count,
-                self.total_file_count
-            )
+
+            progress_graph(self.file_done_count, self.total_file_count)
 
     @timed
     def read_input_files(self):
@@ -315,9 +308,9 @@ class TrainingBlueprint(ABC):
             random_state=random.randrange(0, 1000),
         )
 
-    def scale_data(self,
-                   X_train: np.ndarray,
-                   X_test: np.ndarray) -> Tuple[Any, np.ndarray, np.ndarray]:
+    def scale_data(
+        self, X_train: np.ndarray, X_test: np.ndarray
+    ) -> Tuple[Any, np.ndarray, np.ndarray]:
         logger.debug("Scaling data")
         scaler = MinMaxScaler()
         scaler.fit(X_train)
@@ -325,19 +318,19 @@ class TrainingBlueprint(ABC):
 
     # Note: children should implement.  Consider using @parallelize.
     @abstractmethod
-    def train_model(self,
-                    parameters,
-                    X_train_scaled: np.ndarray,
-                    y_train: np.ndarray) -> Any:
+    def train_model(
+        self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray
+    ) -> Any:
         pass
 
     def evaluate_model(
-            self,
-            model: Any,
-            X_train_scaled: np.ndarray,
-            y_train: np.ndarray,
-            X_test_scaled: np.ndarray,
-            y_test: np.ndarray) -> Tuple[np.float64, np.float64]:
+        self,
+        model: Any,
+        X_train_scaled: np.ndarray,
+        y_train: np.ndarray,
+        X_test_scaled: np.ndarray,
+        y_test: np.ndarray,
+    ) -> Tuple[np.float64, np.float64]:
         logger.debug("Evaluating the model")
         training_score = model.score(X_train_scaled, y_train) * 100.0
         test_score = model.score(X_test_scaled, y_test) * 100.0
@@ -348,13 +341,14 @@ class TrainingBlueprint(ABC):
         return (training_score, test_score)
 
     def maybe_persist_scaler_and_model(
-            self,
-            training_score: np.float64,
-            test_score: np.float64,
-            params: str,
-            num_examples: int,
-            scaler: Any,
-            model: Any) -> Tuple[Optional[str], Optional[str], Optional[str]]:
+        self,
+        training_score: np.float64,
+        test_score: np.float64,
+        params: str,
+        num_examples: int,
+        scaler: Any,
+        model: Any,
+    ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
         if not self.spec.dry_run:
             import datetime_utils
             import input_utils
@@ -368,11 +362,11 @@ Training set score: {training_score:.2f}%
 Testing set score: {test_score:.2f}%"""
             print(f'\n{info}\n')
             if (
-                    (self.spec.persist_percentage_threshold is not None and
-                     test_score > self.spec.persist_percentage_threshold)
-                    or
-                    (not self.spec.quiet
-                     and input_utils.yn_response("Write the model? [y,n]: ") == "y")
+                self.spec.persist_percentage_threshold is not None
+                and test_score > self.spec.persist_percentage_threshold
+            ) or (
+                not self.spec.quiet
+                and input_utils.yn_response("Write the model? [y,n]: ") == "y"
             ):
                 scaler_filename = f"{self.spec.basename}_scaler.sav"
                 with open(scaler_filename, "wb") as f: