Bug and readability fix.
[pyutils.git] / src / pyutils / parallelize / executors.py
index 23fd6eb96c3a9a3198691c705c455902d982fce8..e8f8d40dd30b1eb1ca52c95d0ab282e43e17ea79 100644 (file)
@@ -49,18 +49,26 @@ import time
 import warnings
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from dataclasses import dataclass, fields
+from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional, Set
 
 import cloudpickle  # type: ignore
 from overrides import overrides
 
-import pyutils.typez.histogram as hist
-from pyutils import argparse_utils, config, math_utils, persistent, string_utils
+import pyutils.types.histogram as hist
+from pyutils import (
+    argparse_utils,
+    config,
+    dataclass_utils,
+    math_utils,
+    persistent,
+    string_utils,
+)
 from pyutils.ansi import bg, fg, reset, underline
 from pyutils.decorator_utils import singleton
 from pyutils.exec_utils import cmd_exitcode, cmd_in_background, run_silently
 from pyutils.parallelize.thread_utils import background_thread
+from pyutils.types import type_utils
 
 logger = logging.getLogger(__name__)
 
@@ -106,7 +114,7 @@ parser.add_argument(
     type=str,
     metavar='PATH_TO_REMOTE_WORKER_PY',
     help='Path to remote_worker.py on remote machines',
-    default='source py39-venv/bin/activate && /home/scott/lib/release/pyutils/src/pyutils/remote_worker.py',
+    default=f'source py39-venv/bin/activate && {os.environ["HOME"]}/pyutils/src/pyutils/remote_worker.py',
 )
 
 
@@ -385,7 +393,7 @@ class BundleDetails:
     machine: Optional[str]
     """The remote machine running this bundle or None if none (yet)"""
 
-    hostname: str
+    controller: str
     """The controller machine"""
 
     code_file: str
@@ -580,7 +588,8 @@ class RemoteExecutorStatus:
         total_finished = len(self.finished_bundle_timings)
         total_in_flight = self.total_in_flight()
         ret = f'\n\n{underline()}Remote Executor Pool Status{reset()}: '
-        qall = None
+        qall_median = None
+        qall_p95 = None
         if len(self.finished_bundle_timings) > 1:
             qall_median = self.finished_bundle_timings.get_median()
             qall_p95 = self.finished_bundle_timings.get_percentile(95)
@@ -634,8 +643,8 @@ class RemoteExecutorStatus:
                             if details is not None:
                                 details.slower_than_local_p95 = False
 
-                    if qall is not None:
-                        if sec > qall[1]:
+                    if qall_p95 is not None:
+                        if sec > qall_p95:
                             ret += f'{bg("red")}>∀p95{reset()} '
                             if details is not None:
                                 details.slower_than_global_p95 = True
@@ -1059,7 +1068,7 @@ class RemoteExecutor(BaseExecutor):
 
         self.adjust_task_count(+1)
         uuid = bundle.uuid
-        hostname = bundle.hostname
+        controller = bundle.controller
         avoid_machine = override_avoid_machine
         is_original = bundle.src_bundle is None
 
@@ -1113,7 +1122,7 @@ class RemoteExecutor(BaseExecutor):
                     return None
 
         # Send input code / data to worker machine if it's not local.
-        if hostname not in machine:
+        if controller not in machine:
             try:
                 cmd = (
                     f'{SCP} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
@@ -1266,7 +1275,7 @@ class RemoteExecutor(BaseExecutor):
             bundle.end_ts = time.time()
             if not was_cancelled:
                 assert bundle.machine is not None
-                if bundle.hostname not in bundle.machine:
+                if bundle.controller not in bundle.machine:
                     cmd = f'{SCP} {username}@{machine}:{result_file} {result_file} 2>/dev/null'
                     logger.info(
                         "%s: Fetching results back from %s@%s via %s",
@@ -1379,7 +1388,7 @@ class RemoteExecutor(BaseExecutor):
             worker=None,
             username=None,
             machine=None,
-            hostname=platform.node(),
+            controller=platform.node(),
             code_file=code_file,
             result_file=result_file,
             pid=0,
@@ -1413,7 +1422,7 @@ class RemoteExecutor(BaseExecutor):
             worker=None,
             username=None,
             machine=None,
-            hostname=src_bundle.hostname,
+            controller=src_bundle.controller,
             code_file=src_bundle.code_file,
             result_file=src_bundle.result_file,
             pid=0,
@@ -1477,11 +1486,10 @@ class RemoteExecutor(BaseExecutor):
                 raise RemoteExecutorException(
                     f'{bundle}: This bundle can\'t be completed despite several backups and retries',
                 )
-            else:
-                logger.error(
-                    '%s: At least it\'s only a backup; better luck with the others.',
-                    bundle,
-                )
+            logger.error(
+                '%s: At least it\'s only a backup; better luck with the others.',
+                bundle,
+            )
             return None
         else:
             msg = f'>>> Emergency rescheduling {bundle} because of unexected errors (wtf?!) <<<'
@@ -1527,16 +1535,10 @@ class ConfigRemoteWorkerPoolProvider(
         self.remote_worker_pool = []
         for record in json_remote_worker_pool['remote_worker_records']:
             self.remote_worker_pool.append(
-                self.dataclassFromDict(RemoteWorkerRecord, record)
+                dataclass_utils.dataclass_from_dict(RemoteWorkerRecord, record)
             )
         assert len(self.remote_worker_pool) > 0
 
-    @staticmethod
-    def dataclassFromDict(clsName, argDict: Dict[str, Any]) -> Any:
-        fieldSet = {f.name for f in fields(clsName) if f.init}
-        filteredArgDict = {k: v for k, v in argDict.items() if k in fieldSet}
-        return clsName(**filteredArgDict)
-
     @overrides
     def get_remote_workers(self) -> List[RemoteWorkerRecord]:
         return self.remote_worker_pool
@@ -1548,7 +1550,7 @@ class ConfigRemoteWorkerPoolProvider(
     @staticmethod
     @overrides
     def get_filename() -> str:
-        return config.config['remote_worker_records_file']
+        return type_utils.unwrap_optional(config.config['remote_worker_records_file'])
 
     @staticmethod
     @overrides