from __future__ import annotations

import concurrent.futures as fut
import json
import logging
import os
import platform
import random
import re
import subprocess
import threading
import time
@persistent.persistent_autoloaded_singleton()  # type: ignore
class ConfigRemoteWorkerPoolProvider(RemoteWorkerPoolProvider, persistent.JsonFileBasedPersistent):
    """Reads the remote worker pool out of a JSON config file.

    The JSON (located via config['remote_worker_records_file']) is expected
    to contain a top-level 'remote_worker_records' list; each entry is
    deserialized into a :class:`RemoteWorkerRecord` dataclass.  The pool is
    read-only at runtime: it is loaded on demand and never written back
    (the file is meant to be edited by hand).
    """

    def __init__(self, json_remote_worker_pool: Dict[str, Any]):
        """Build the worker pool from already-parsed JSON data.

        Args:
            json_remote_worker_pool: parsed JSON dict; must contain a
                non-empty 'remote_worker_records' list of record dicts.
        """
        self.remote_worker_pool: List[RemoteWorkerRecord] = [
            self.dataclassFromDict(RemoteWorkerRecord, record)
            for record in json_remote_worker_pool['remote_worker_records']
        ]
        # An empty pool is a configuration error; fail loudly at startup.
        assert len(self.remote_worker_pool) > 0

    @staticmethod
    def dataclassFromDict(clsName, argDict: Dict[str, Any]) -> Any:
        """Instantiate dataclass `clsName` from `argDict`.

        Keys in `argDict` that do not correspond to an __init__-settable
        dataclass field are silently dropped, so stale/extra keys in the
        config file do not break deserialization.
        """
        fieldSet = {f.name for f in fields(clsName) if f.init}
        filteredArgDict = {k: v for k, v in argDict.items() if k in fieldSet}
        return clsName(**filteredArgDict)

    @overrides
    def get_remote_workers(self) -> List[RemoteWorkerRecord]:
        """Return the deserialized remote worker records."""
        return self.remote_worker_pool

    @overrides
    def get_persistent_data(self) -> List[RemoteWorkerRecord]:
        """Data the persistence layer would save (same as the pool)."""
        return self.remote_worker_pool

    @staticmethod
    @overrides
    def get_filename() -> str:
        """Path of the JSON config file holding the worker records."""
        return config.config['remote_worker_records_file']

    @staticmethod
    @overrides
    def should_we_load_data(filename: str) -> bool:
        """Always (re)load from the config file."""
        return True

    @staticmethod
    @overrides
    def should_we_save_data(filename: str) -> bool:
        """Never save; the config file is maintained by hand."""
        return False
@singleton
if record.machine == platform.node() and record.count > 1:
logger.info('Reducing workload for %s.', record.machine)
record.count = max(int(record.count / 2), 1)
- print(json.dumps(record.__dict__))
policy = WeightedRandomRemoteWorkerSelectionPolicy()
policy.register_worker_pool(pool)