Easier and more self-documenting patterns for loading/saving Persistent [branch: master]
author: Scott Gasch <[email protected]>
Thu, 1 Sep 2022 18:42:34 +0000 (11:42 -0700)
committer: Scott Gasch <[email protected]>
Thu, 1 Sep 2022 18:42:34 +0000 (11:42 -0700)
state via pickle and JSON.

cached/weather_data.py
cached/weather_forecast.py
executors.py
persistent.py

index 91d665dbfd2e068ac2a10fc1ff867d552db3e71b..86978c858c02ca4b15c936c0af2f3b9b6a950078 100644 (file)
@@ -74,7 +74,7 @@ class WeatherData:
 
 
 @persistent.persistent_autoloaded_singleton()  # type: ignore
-class CachedWeatherData(persistent.Persistent):
+class CachedWeatherData(persistent.PicklingFileBasedPersistent):
     def __init__(self, weather_data: Dict[datetime.date, WeatherData] = None):
         """C'tor.  Do not pass a dict except for testing purposes.
 
@@ -207,42 +207,24 @@ class CachedWeatherData(persistent.Persistent):
                 icon=icon,
             )
 
-    @classmethod
+    @staticmethod
     @overrides
-    def load(cls) -> Any:
+    def get_filename() -> str:
+        return config.config['weather_data_cachefile']
 
-        """Depending on whether we have fresh data persisted either uses that
-        data to instantiate the class or makes an HTTP request to fetch the
-        necessary data.
-
-        Note that because this is a subclass of Persistent this is taken
-        care of automatically.
-        """
+    @staticmethod
+    @overrides
+    def should_we_save_data(filename: str) -> bool:
+        return True
 
-        if persistent.was_file_written_within_n_seconds(
-            config.config['weather_data_cachefile'],
+    @staticmethod
+    @overrides
+    def should_we_load_data(filename: str) -> bool:
+        return persistent.was_file_written_within_n_seconds(
+            filename,
             config.config['weather_data_stalest_acceptable'].total_seconds(),
-        ):
-            import pickle
-
-            with open(config.config['weather_data_cachefile'], 'rb') as rf:
-                weather_data = pickle.load(rf)
-                return cls(weather_data)
-        return None
+        )
 
     @overrides
-    def save(self) -> bool:
-        """
-        Saves the current data to disk if required.  Again, because this is
-        a subclass of Persistent this is taken care of for you.
-        """
-
-        import pickle
-
-        with open(config.config['weather_data_cachefile'], 'wb') as wf:
-            pickle.dump(
-                self.weather_data,
-                wf,
-                pickle.HIGHEST_PROTOCOL,
-            )
-        return True
+    def get_persistent_data(self) -> Any:
+        return self.weather_data
index b8a20ed8caa04553f01225a5a2f8f57866858ea7..7973bbb4c8b38b185042446efb82cfc2c72e130b 100644 (file)
@@ -56,7 +56,7 @@ class WeatherForecast:
 
 
 @persistent.persistent_autoloaded_singleton()  # type: ignore
-class CachedDetailedWeatherForecast(persistent.Persistent):
+class CachedDetailedWeatherForecast(persistent.PicklingFileBasedPersistent):
     def __init__(self, forecasts=None):
         if forecasts is not None:
             self.forecasts = forecasts
@@ -119,28 +119,24 @@ class CachedDetailedWeatherForecast(persistent.Persistent):
                     description=blurb,
                 )
 
-    @classmethod
+    @staticmethod
     @overrides
-    def load(cls) -> Any:
-        if persistent.was_file_written_within_n_seconds(
-            config.config['weather_forecast_cachefile'],
-            config.config['weather_forecast_stalest_acceptable'].total_seconds(),
-        ):
-            import pickle
-
-            with open(config.config['weather_forecast_cachefile'], 'rb') as rf:
-                weather_data = pickle.load(rf)
-                return cls(weather_data)
-        return None
+    def get_filename() -> str:
+        return config.config['weather_forecast_cachefile']
 
+    @staticmethod
     @overrides
-    def save(self) -> bool:
-        import pickle
-
-        with open(config.config['weather_forecast_cachefile'], 'wb') as wf:
-            pickle.dump(
-                self.forecasts,
-                wf,
-                pickle.HIGHEST_PROTOCOL,
-            )
+    def should_we_save_data(filename: str) -> bool:
         return True
+
+    @staticmethod
+    @overrides
+    def should_we_load_data(filename: str) -> bool:
+        return persistent.was_file_written_within_n_seconds(
+            filename,
+            config.config['weather_forecast_stalest_acceptable'].total_seconds(),
+        )
+
+    @overrides
+    def get_persistent_data(self) -> Any:
+        return self.forecasts
index cce0870ff4430dd418e5912b4c7ee43d881af903..2794ca18f6667fef64097272abd3bc4f58896298 100644 (file)
@@ -13,12 +13,10 @@ global executors / worker pools with automatic shutdown semantics."""
 
 from __future__ import annotations
 import concurrent.futures as fut
-import json
 import logging
 import os
 import platform
 import random
-import re
 import subprocess
 import threading
 import time
@@ -1350,44 +1348,41 @@ class RemoteWorkerPoolProvider:
 
 
 @persistent.persistent_autoloaded_singleton()  # type: ignore
-class ConfigRemoteWorkerPoolProvider(RemoteWorkerPoolProvider, persistent.Persistent):
-    def __init__(self, remote_worker_pool: List[RemoteWorkerRecord]):
-        self.remote_worker_pool = remote_worker_pool
+class ConfigRemoteWorkerPoolProvider(RemoteWorkerPoolProvider, persistent.JsonFileBasedPersistent):
+    def __init__(self, json_remote_worker_pool: Dict[str, Any]):
+        self.remote_worker_pool = []
+        for record in json_remote_worker_pool['remote_worker_records']:
+            self.remote_worker_pool.append(self.dataclassFromDict(RemoteWorkerRecord, record))
+        assert len(self.remote_worker_pool) > 0
+
+    @staticmethod
+    def dataclassFromDict(clsName, argDict: Dict[str, Any]) -> Any:
+        fieldSet = {f.name for f in fields(clsName) if f.init}
+        filteredArgDict = {k: v for k, v in argDict.items() if k in fieldSet}
+        return clsName(**filteredArgDict)
 
     @overrides
     def get_remote_workers(self) -> List[RemoteWorkerRecord]:
         return self.remote_worker_pool
 
-    @staticmethod
-    def dataclassFromDict(className, argDict: Dict[str, Any]) -> Any:
-        fieldSet = {f.name for f in fields(className) if f.init}
-        filteredArgDict = {k: v for k, v in argDict.items() if k in fieldSet}
-        return className(**filteredArgDict)
-
-    @classmethod
-    def load(cls) -> List[RemoteWorkerRecord]:
-        try:
-            with open(config.config['remote_worker_records_file'], 'rb') as rf:
-                lines = rf.readlines()
+    @overrides
+    def get_persistent_data(self) -> List[RemoteWorkerRecord]:
+        return self.remote_worker_pool
 
-            buf = ''
-            for line in lines:
-                line = line.decode()
-                line = re.sub(r'#.*$', '', line)
-                buf += line
+    @staticmethod
+    @overrides
+    def get_filename() -> str:
+        return config.config['remote_worker_records_file']
 
-            pool = []
-            remote_worker_pool = json.loads(buf)
-            for record in remote_worker_pool['remote_worker_records']:
-                pool.append(cls.dataclassFromDict(RemoteWorkerRecord, record))
-            return cls(pool)
-        except Exception as e:
-            raise Exception('Failed to parse JSON remote worker pool data.') from e
+    @staticmethod
+    @overrides
+    def should_we_load_data(filename: str) -> bool:
+        return True
 
+    @staticmethod
     @overrides
-    def save(self) -> bool:
-        """We don't save the config; it should be edited by the user by hand."""
-        pass
+    def should_we_save_data(filename: str) -> bool:
+        return False
 
 
 @singleton
@@ -1471,7 +1466,6 @@ class DefaultExecutors(object):
                 if record.machine == platform.node() and record.count > 1:
                     logger.info('Reducing workload for %s.', record.machine)
                     record.count = max(int(record.count / 2), 1)
-                print(json.dumps(record.__dict__))
 
             policy = WeightedRandomRemoteWorkerSelectionPolicy()
             policy.register_worker_pool(pool)
index 6cc444cbf85afb13fcf3fe6b1d06b4a7ee1cfca1..950471ec7e323a8d13facccc398d4b84f36cda63 100644 (file)
@@ -11,8 +11,11 @@ import datetime
 import enum
 import functools
 import logging
+import re
 from abc import ABC, abstractmethod
-from typing import Any
+from typing import Any, Optional
+
+from overrides import overrides
 
 import file_utils
 
@@ -59,6 +62,113 @@ class Persistent(ABC):
         pass
 
 
+class FileBasedPersistent(Persistent):
+    """A Persistent that uses a file to save/load data and knows the conditions
+    under which the state should be saved/loaded."""
+
+    @staticmethod
+    @abstractmethod
+    def get_filename() -> str:
+        """Since this class saves/loads to/from a file, what's its full path?"""
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def should_we_save_data(filename: str) -> bool:
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def should_we_load_data(filename: str) -> bool:
+        pass
+
+    @abstractmethod
+    def get_persistent_data(self) -> Any:
+        pass
+
+
+class PicklingFileBasedPersistent(FileBasedPersistent):
+    @classmethod
+    @overrides
+    def load(cls) -> Optional[Any]:
+        filename = cls.get_filename()
+        if cls.should_we_load_data(filename):
+            logger.debug('Attempting to load state from %s', filename)
+
+            import pickle
+
+            try:
+                with open(filename, 'rb') as rf:
+                    data = pickle.load(rf)
+                    return cls(data)
+
+            except Exception as e:
+                raise Exception(f'Failed to load {filename}.') from e
+        return None
+
+    @overrides
+    def save(self) -> bool:
+        filename = self.get_filename()
+        if self.should_we_save_data(filename):
+            logger.debug('Trying to save state in %s', filename)
+            try:
+                import pickle
+
+                with open(filename, 'wb') as wf:
+                    pickle.dump(self.get_persistent_data(), wf, pickle.HIGHEST_PROTOCOL)
+                return True
+            except Exception as e:
+                raise Exception(f'Failed to save to {filename}.') from e
+        return False
+
+
+class JsonFileBasedPersistent(FileBasedPersistent):
+    @classmethod
+    @overrides
+    def load(cls) -> Any:
+        filename = cls.get_filename()
+        if cls.should_we_load_data(filename):
+            logger.debug('Trying to load state from %s', filename)
+            import json
+
+            try:
+                with open(filename, 'r') as rf:
+                    lines = rf.readlines()
+
+                # This is probably bad... but I like comments
+                # in config files and JSON doesn't support them.  So
+                # pre-process the buffer to remove comments thus
+                # allowing people to add them.
+                buf = ''
+                for line in lines:
+                    line = re.sub(r'#.*$', '', line)
+                    buf += line
+
+                json_dict = json.loads(buf)
+                return cls(json_dict)
+
+            except Exception as e:
+                logger.exception(e)
+                raise Exception(f'Failed to load {filename}.') from e
+        return None
+
+    @overrides
+    def save(self) -> bool:
+        filename = self.get_filename()
+        if self.should_we_save_data(filename):
+            logger.debug('Trying to save state in %s', filename)
+            try:
+                import json
+
+                json_blob = json.dumps(self.get_persistent_data())
+                with open(filename, 'w') as wf:
+                    wf.writelines(json_blob)
+                return True
+            except Exception as e:
+                raise Exception(f'Failed to save to {filename}.') from e
+        return False
+
+
 def was_file_written_today(filename: str) -> bool:
     """Convenience wrapper around :meth:`was_file_written_within_n_seconds`.
 
@@ -225,10 +335,6 @@ class persistent_autoloaded_singleton(object):
         return _load
 
 
-# TODO: PicklingPersistant?
-# TODO: JsonConfigPersistant?
-
-
 if __name__ == '__main__':
     import doctest