Get rid of hardcoded remote worker pool in favor of a config file.
authorScott Gasch <[email protected]>
Thu, 1 Sep 2022 17:21:11 +0000 (10:21 -0700)
committerScott Gasch <[email protected]>
Thu, 1 Sep 2022 17:21:11 +0000 (10:21 -0700)
executors.py
persistent.py

index 6485afa054689c3b668adb1e0708b7f2d29ed8b9..cce0870ff4430dd418e5912b4c7ee43d881af903 100644 (file)
@@ -13,17 +13,19 @@ global executors / worker pools with automatic shutdown semantics."""
 
 from __future__ import annotations
 import concurrent.futures as fut
+import json
 import logging
 import os
 import platform
 import random
+import re
 import subprocess
 import threading
 import time
 import warnings
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from dataclasses import dataclass
+from dataclasses import dataclass, fields
 from typing import Any, Callable, Dict, List, Optional, Set
 
 import cloudpickle  # type: ignore
@@ -33,6 +35,7 @@ from overrides import overrides
 import argparse_utils
 import config
 import histogram as hist
+import persistent
 import string_utils
 from ansi import bg, fg, reset, underline
 from decorator_utils import singleton
@@ -71,6 +74,14 @@ parser.add_argument(
     metavar='#FAILURES',
     help='Maximum number of failures before giving up on a bundle',
 )
+parser.add_argument(
+    '--remote_worker_records_file',
+    type=str,
+    metavar='FILENAME',
+    help='Path of the remote worker records file (JSON)',
+    default=f'{os.environ.get("HOME", ".")}/.remote_worker_records',
+)
+
 
 SSH = '/usr/bin/ssh -oForwardX11=no'
 SCP = '/usr/bin/scp -C'
@@ -1332,6 +1343,53 @@ class RemoteExecutor(BaseExecutor):
             self.already_shutdown = True
 
 
+class RemoteWorkerPoolProvider(ABC):
+    @abstractmethod
+    def get_remote_workers(self) -> List[RemoteWorkerRecord]:
+        pass
+
+
+@persistent.persistent_autoloaded_singleton()  # type: ignore
+class ConfigRemoteWorkerPoolProvider(RemoteWorkerPoolProvider, persistent.Persistent):
+    def __init__(self, remote_worker_pool: List[RemoteWorkerRecord]):
+        self.remote_worker_pool = remote_worker_pool
+
+    @overrides
+    def get_remote_workers(self) -> List[RemoteWorkerRecord]:
+        return self.remote_worker_pool
+
+    @staticmethod
+    def dataclassFromDict(className, argDict: Dict[str, Any]) -> Any:
+        fieldSet = {f.name for f in fields(className) if f.init}
+        filteredArgDict = {k: v for k, v in argDict.items() if k in fieldSet}
+        return className(**filteredArgDict)
+
+    @classmethod
+    def load(cls) -> 'ConfigRemoteWorkerPoolProvider':
+        try:
+            with open(config.config['remote_worker_records_file'], 'rb') as rf:
+                lines = rf.readlines()
+
+            buf = ''
+            for line in lines:
+                line = line.decode()
+                line = re.sub(r'#.*$', '', line)
+                buf += line
+
+            pool = []
+            remote_worker_pool = json.loads(buf)
+            for record in remote_worker_pool['remote_worker_records']:
+                pool.append(cls.dataclassFromDict(RemoteWorkerRecord, record))
+            return cls(pool)
+        except Exception as e:
+            raise Exception('Failed to parse JSON remote worker pool data.') from e
+
+    @overrides
+    def save(self) -> bool:
+        """We don't save the config; it should be edited by the user by hand."""
+        pass
+
+
 @singleton
 class DefaultExecutors(object):
     """A container for a default thread, process and remote executor.
@@ -1398,63 +1456,22 @@ class DefaultExecutors(object):
     def remote_pool(self) -> RemoteExecutor:
         if self.remote_executor is None:
             logger.info('Looking for some helper machines...')
-            pool: List[RemoteWorkerRecord] = []
-            if self._ping('cheetah.house'):
-                logger.info('Found cheetah.house')
-                pool.append(
-                    RemoteWorkerRecord(
-                        username='scott',
-                        machine='cheetah.house',
-                        weight=24,
-                        count=5,
-                    ),
-                )
-            if self._ping('meerkat.cabin'):
-                logger.info('Found meerkat.cabin')
-                pool.append(
-                    RemoteWorkerRecord(
-                        username='scott',
-                        machine='meerkat.cabin',
-                        weight=12,
-                        count=2,
-                    ),
-                )
-            if self._ping('wannabe.house'):
-                logger.info('Found wannabe.house')
-                pool.append(
-                    RemoteWorkerRecord(
-                        username='scott',
-                        machine='wannabe.house',
-                        weight=14,
-                        count=2,
-                    ),
-                )
-            if self._ping('puma.cabin'):
-                logger.info('Found puma.cabin')
-                pool.append(
-                    RemoteWorkerRecord(
-                        username='scott',
-                        machine='puma.cabin',
-                        weight=24,
-                        count=5,
-                    ),
-                )
-            if self._ping('backup.house'):
-                logger.info('Found backup.house')
-                pool.append(
-                    RemoteWorkerRecord(
-                        username='scott',
-                        machine='backup.house',
-                        weight=9,
-                        count=2,
-                    ),
-                )
+            provider = ConfigRemoteWorkerPoolProvider()
+            all_machines = provider.get_remote_workers()
+            pool = []
+
+            # Make sure we can ping each machine.
+            for record in all_machines:
+                if self._ping(record.machine):
+                    logger.info('%s is alive / responding to pings', record.machine)
+                    pool.append(record)
 
             # The controller machine has a lot to do; go easy on it.
             for record in pool:
                 if record.machine == platform.node() and record.count > 1:
                     logger.info('Reducing workload for %s.', record.machine)
                     record.count = max(int(record.count / 2), 1)
+                logger.debug('Final record: %s', json.dumps(record.__dict__))
 
             policy = WeightedRandomRemoteWorkerSelectionPolicy()
             policy.register_worker_pool(pool)
index 58014608f9a894fa40742249cd9c04933007d52e..6cc444cbf85afb13fcf3fe6b1d06b4a7ee1cfca1 100644 (file)
@@ -225,6 +225,10 @@ class persistent_autoloaded_singleton(object):
         return _load
 
 
+# TODO: PicklingPersistent?
+# TODO: JsonConfigPersistent?
+
+
 if __name__ == '__main__':
     import doctest