Easier and more self documenting patterns for loading/saving Persistent
[python_utils.git] / executors.py
index cce0870ff4430dd418e5912b4c7ee43d881af903..2794ca18f6667fef64097272abd3bc4f58896298 100644 (file)
@@ -13,12 +13,10 @@ global executors / worker pools with automatic shutdown semantics."""
 
 from __future__ import annotations
 import concurrent.futures as fut
-import json
 import logging
 import os
 import platform
 import random
-import re
 import subprocess
 import threading
 import time
@@ -1350,44 +1348,41 @@ class RemoteWorkerPoolProvider:
 
 
 @persistent.persistent_autoloaded_singleton()  # type: ignore
-class ConfigRemoteWorkerPoolProvider(RemoteWorkerPoolProvider, persistent.Persistent):
-    def __init__(self, remote_worker_pool: List[RemoteWorkerRecord]):
-        self.remote_worker_pool = remote_worker_pool
+class ConfigRemoteWorkerPoolProvider(RemoteWorkerPoolProvider, persistent.JsonFileBasedPersistent):
+    def __init__(self, json_remote_worker_pool: Dict[str, Any]):
+        self.remote_worker_pool = []
+        for record in json_remote_worker_pool['remote_worker_records']:
+            self.remote_worker_pool.append(self.dataclassFromDict(RemoteWorkerRecord, record))
+        assert len(self.remote_worker_pool) > 0
+
+    @staticmethod
+    def dataclassFromDict(clsName, argDict: Dict[str, Any]) -> Any:
+        fieldSet = {f.name for f in fields(clsName) if f.init}
+        filteredArgDict = {k: v for k, v in argDict.items() if k in fieldSet}
+        return clsName(**filteredArgDict)
 
     @overrides
     def get_remote_workers(self) -> List[RemoteWorkerRecord]:
         return self.remote_worker_pool
 
-    @staticmethod
-    def dataclassFromDict(className, argDict: Dict[str, Any]) -> Any:
-        fieldSet = {f.name for f in fields(className) if f.init}
-        filteredArgDict = {k: v for k, v in argDict.items() if k in fieldSet}
-        return className(**filteredArgDict)
-
-    @classmethod
-    def load(cls) -> List[RemoteWorkerRecord]:
-        try:
-            with open(config.config['remote_worker_records_file'], 'rb') as rf:
-                lines = rf.readlines()
+    @overrides
+    def get_persistent_data(self) -> List[RemoteWorkerRecord]:
+        return self.remote_worker_pool
 
-            buf = ''
-            for line in lines:
-                line = line.decode()
-                line = re.sub(r'#.*$', '', line)
-                buf += line
+    @staticmethod
+    @overrides
+    def get_filename() -> str:
+        return config.config['remote_worker_records_file']
 
-            pool = []
-            remote_worker_pool = json.loads(buf)
-            for record in remote_worker_pool['remote_worker_records']:
-                pool.append(cls.dataclassFromDict(RemoteWorkerRecord, record))
-            return cls(pool)
-        except Exception as e:
-            raise Exception('Failed to parse JSON remote worker pool data.') from e
+    @staticmethod
+    @overrides
+    def should_we_load_data(filename: str) -> bool:
+        return True
 
+    @staticmethod
     @overrides
-    def save(self) -> bool:
-        """We don't save the config; it should be edited by the user by hand."""
-        pass
+    def should_we_save_data(filename: str) -> bool:
+        return False
 
 
 @singleton
@@ -1471,7 +1466,6 @@ class DefaultExecutors(object):
                 if record.machine == platform.node() and record.count > 1:
                     logger.info('Reducing workload for %s.', record.machine)
                     record.count = max(int(record.count / 2), 1)
-                print(json.dumps(record.__dict__))
 
             policy = WeightedRandomRemoteWorkerSelectionPolicy()
             policy.register_worker_pool(pool)