from __future__ import annotations
import concurrent.futures as fut
+import json
import logging
import os
import platform
import random
+import re
import subprocess
import threading
import time
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
-from dataclasses import dataclass
+from dataclasses import dataclass, fields
from typing import Any, Callable, Dict, List, Optional, Set
import cloudpickle # type: ignore
import argparse_utils
import config
import histogram as hist
+import persistent
import string_utils
from ansi import bg, fg, reset, underline
from decorator_utils import singleton
metavar='#FAILURES',
help='Maximum number of failures before giving up on a bundle',
)
+parser.add_argument(
+    '--remote_worker_records_file',
+    type=str,
+    metavar='FILENAME',
+    help='Path of the remote worker records file (JSON)',
+    # Default: a dotfile in $HOME, falling back to the current
+    # directory when HOME is unset.
+    default=f'{os.environ.get("HOME", ".")}/.remote_worker_records',
+)
+
SSH = '/usr/bin/ssh -oForwardX11=no'
SCP = '/usr/bin/scp -C'
self.already_shutdown = True
+class RemoteWorkerPoolProvider:
+ @abstractmethod
+ def get_remote_workers(self) -> List[RemoteWorkerRecord]:
+ pass
+
+
+class ConfigRemoteWorkerPoolProvider(RemoteWorkerPoolProvider, persistent.Persistent):
+ def __init__(self, remote_worker_pool: List[RemoteWorkerRecord]):
+ self.remote_worker_pool = remote_worker_pool
+
+ @overrides
+ def get_remote_workers(self) -> List[RemoteWorkerRecord]:
+ return self.remote_worker_pool
+
+ @staticmethod
+ def dataclassFromDict(className, argDict: Dict[str, Any]) -> Any:
+ fieldSet = {f.name for f in fields(className) if f.init}
+ filteredArgDict = {k: v for k, v in argDict.items() if k in fieldSet}
+ return className(**filteredArgDict)
+
+ @classmethod
+ def load(cls) -> List[RemoteWorkerRecord]:
+ try:
+ with open(config.config['remote_worker_records_file'], 'rb') as rf:
+ lines = rf.readlines()
+
+ buf = ''
+ for line in lines:
+ line = line.decode()
+ line = re.sub(r'#.*$', '', line)
+ buf += line
+
+ pool = []
+ remote_worker_pool = json.loads(buf)
+ for record in remote_worker_pool['remote_worker_records']:
+ pool.append(cls.dataclassFromDict(RemoteWorkerRecord, record))
+ return cls(pool)
+ except Exception as e:
+ raise Exception('Failed to parse JSON remote worker pool data.') from e
+
+ @overrides
+ def save(self) -> bool:
+ """We don't save the config; it should be edited by the user by hand."""
+ pass
+
+
@singleton
class DefaultExecutors(object):
"""A container for a default thread, process and remote executor.
def remote_pool(self) -> RemoteExecutor:
if self.remote_executor is None:
logger.info('Looking for some helper machines...')
- pool: List[RemoteWorkerRecord] = []
- if self._ping('cheetah.house'):
- logger.info('Found cheetah.house')
- pool.append(
- RemoteWorkerRecord(
- username='scott',
- machine='cheetah.house',
- weight=24,
- count=5,
- ),
- )
- if self._ping('meerkat.cabin'):
- logger.info('Found meerkat.cabin')
- pool.append(
- RemoteWorkerRecord(
- username='scott',
- machine='meerkat.cabin',
- weight=12,
- count=2,
- ),
- )
- if self._ping('wannabe.house'):
- logger.info('Found wannabe.house')
- pool.append(
- RemoteWorkerRecord(
- username='scott',
- machine='wannabe.house',
- weight=14,
- count=2,
- ),
- )
- if self._ping('puma.cabin'):
- logger.info('Found puma.cabin')
- pool.append(
- RemoteWorkerRecord(
- username='scott',
- machine='puma.cabin',
- weight=24,
- count=5,
- ),
- )
- if self._ping('backup.house'):
- logger.info('Found backup.house')
- pool.append(
- RemoteWorkerRecord(
- username='scott',
- machine='backup.house',
- weight=9,
- count=2,
- ),
- )
+ provider = ConfigRemoteWorkerPoolProvider()
+ all_machines = provider.get_remote_workers()
+ pool = []
+
+ # Make sure we can ping each machine.
+ for record in all_machines:
+ if self._ping(record.machine):
+ logger.info('%s is alive / responding to pings', record.machine)
+ pool.append(record)
# The controller machine has a lot to do; go easy on it.
for record in pool:
if record.machine == platform.node() and record.count > 1:
logger.info('Reducing workload for %s.', record.machine)
record.count = max(int(record.count / 2), 1)
+ print(json.dumps(record.__dict__))
policy = WeightedRandomRemoteWorkerSelectionPolicy()
policy.register_worker_pool(pool)