From: Scott Gasch Date: Thu, 1 Sep 2022 18:42:34 +0000 (-0700) Subject: Easier and more self documenting patterns for loading/saving Persistent X-Git-Url: https://wannabe.guru.org/gitweb/?a=commitdiff_plain;h=HEAD;p=python_utils.git Easier and more self documenting patterns for loading/saving Persistent state via pickle and json. --- diff --git a/cached/weather_data.py b/cached/weather_data.py index 91d665d..86978c8 100644 --- a/cached/weather_data.py +++ b/cached/weather_data.py @@ -74,7 +74,7 @@ class WeatherData: @persistent.persistent_autoloaded_singleton() # type: ignore -class CachedWeatherData(persistent.Persistent): +class CachedWeatherData(persistent.PicklingFileBasedPersistent): def __init__(self, weather_data: Dict[datetime.date, WeatherData] = None): """C'tor. Do not pass a dict except for testing purposes. @@ -207,42 +207,24 @@ class CachedWeatherData(persistent.Persistent): icon=icon, ) - @classmethod + @staticmethod @overrides - def load(cls) -> Any: + def get_filename() -> str: + return config.config['weather_data_cachefile'] - """Depending on whether we have fresh data persisted either uses that - data to instantiate the class or makes an HTTP request to fetch the - necessary data. - - Note that because this is a subclass of Persistent this is taken - care of automatically. - """ + @staticmethod + @overrides + def should_we_save_data(filename: str) -> bool: + return True - if persistent.was_file_written_within_n_seconds( - config.config['weather_data_cachefile'], + @staticmethod + @overrides + def should_we_load_data(filename: str) -> bool: + return persistent.was_file_written_within_n_seconds( + filename, config.config['weather_data_stalest_acceptable'].total_seconds(), - ): - import pickle - - with open(config.config['weather_data_cachefile'], 'rb') as rf: - weather_data = pickle.load(rf) - return cls(weather_data) - return None + ) @overrides - def save(self) -> bool: - """ - Saves the current data to disk if required. Again, because this is - a subclass of Persistent this is taken care of for you. - """ - - import pickle - - with open(config.config['weather_data_cachefile'], 'wb') as wf: - pickle.dump( - self.weather_data, - wf, - pickle.HIGHEST_PROTOCOL, - ) - return True + def get_persistent_data(self) -> Any: + return self.weather_data diff --git a/cached/weather_forecast.py b/cached/weather_forecast.py index b8a20ed..7973bbb 100644 --- a/cached/weather_forecast.py +++ b/cached/weather_forecast.py @@ -56,7 +56,7 @@ class WeatherForecast: @persistent.persistent_autoloaded_singleton() # type: ignore -class CachedDetailedWeatherForecast(persistent.Persistent): +class CachedDetailedWeatherForecast(persistent.PicklingFileBasedPersistent): def __init__(self, forecasts=None): if forecasts is not None: self.forecasts = forecasts @@ -119,28 +119,24 @@ class CachedDetailedWeatherForecast(persistent.Persistent): description=blurb, ) - @classmethod + @staticmethod @overrides - def load(cls) -> Any: - if persistent.was_file_written_within_n_seconds( - config.config['weather_forecast_cachefile'], - config.config['weather_forecast_stalest_acceptable'].total_seconds(), - ): - import pickle - - with open(config.config['weather_forecast_cachefile'], 'rb') as rf: - weather_data = pickle.load(rf) - return cls(weather_data) - return None + def get_filename() -> str: + return config.config['weather_forecast_cachefile'] + @staticmethod @overrides - def save(self) -> bool: - import pickle - - with open(config.config['weather_forecast_cachefile'], 'wb') as wf: - pickle.dump( - self.forecasts, - wf, - pickle.HIGHEST_PROTOCOL, - ) + def should_we_save_data(filename: str) -> bool: return True + + @staticmethod + @overrides + def should_we_load_data(filename: str) -> bool: + return persistent.was_file_written_within_n_seconds( + filename, + config.config['weather_forecast_stalest_acceptable'].total_seconds(), + ) + + @overrides + def get_persistent_data(self) -> Any: + return self.forecasts diff --git a/executors.py b/executors.py index cce0870..2794ca1 100644 --- a/executors.py +++ b/executors.py @@ -13,12 +13,10 @@ global executors / worker pools with automatic shutdown semantics.""" from __future__ import annotations import concurrent.futures as fut -import json import logging import os import platform import random -import re import subprocess import threading import time @@ -1350,44 +1348,41 @@ class RemoteWorkerPoolProvider: @persistent.persistent_autoloaded_singleton() # type: ignore -class ConfigRemoteWorkerPoolProvider(RemoteWorkerPoolProvider, persistent.Persistent): - def __init__(self, remote_worker_pool: List[RemoteWorkerRecord]): - self.remote_worker_pool = remote_worker_pool +class ConfigRemoteWorkerPoolProvider(RemoteWorkerPoolProvider, persistent.JsonFileBasedPersistent): + def __init__(self, json_remote_worker_pool: Dict[str, Any]): + self.remote_worker_pool = [] + for record in json_remote_worker_pool['remote_worker_records']: + self.remote_worker_pool.append(self.dataclassFromDict(RemoteWorkerRecord, record)) + assert len(self.remote_worker_pool) > 0 + + @staticmethod + def dataclassFromDict(clsName, argDict: Dict[str, Any]) -> Any: + fieldSet = {f.name for f in fields(clsName) if f.init} + filteredArgDict = {k: v for k, v in argDict.items() if k in fieldSet} + return clsName(**filteredArgDict) @overrides def get_remote_workers(self) -> List[RemoteWorkerRecord]: return self.remote_worker_pool - @staticmethod - def dataclassFromDict(className, argDict: Dict[str, Any]) -> Any: - fieldSet = {f.name for f in fields(className) if f.init} - filteredArgDict = {k: v for k, v in argDict.items() if k in fieldSet} - return className(**filteredArgDict) - - @classmethod - def load(cls) -> List[RemoteWorkerRecord]: - try: - with open(config.config['remote_worker_records_file'], 'rb') as rf: - lines = rf.readlines() + @overrides + def get_persistent_data(self) -> List[RemoteWorkerRecord]: + return self.remote_worker_pool - buf = '' - for line in lines: - line = line.decode() - line = re.sub(r'#.*$', '', line) - buf += line + @staticmethod + @overrides + def get_filename() -> str: + return config.config['remote_worker_records_file'] - pool = [] - remote_worker_pool = json.loads(buf) - for record in remote_worker_pool['remote_worker_records']: - pool.append(cls.dataclassFromDict(RemoteWorkerRecord, record)) - return cls(pool) - except Exception as e: - raise Exception('Failed to parse JSON remote worker pool data.') from e + @staticmethod + @overrides + def should_we_load_data(filename: str) -> bool: + return True + @staticmethod @overrides - def save(self) -> bool: - """We don't save the config; it should be edited by the user by hand.""" - pass + def should_we_save_data(filename: str) -> bool: + return False @singleton @@ -1471,7 +1466,6 @@ class DefaultExecutors(object): if record.machine == platform.node() and record.count > 1: logger.info('Reducing workload for %s.', record.machine) record.count = max(int(record.count / 2), 1) - print(json.dumps(record.__dict__)) policy = WeightedRandomRemoteWorkerSelectionPolicy() policy.register_worker_pool(pool) diff --git a/persistent.py b/persistent.py index 6cc444c..950471e 100644 --- a/persistent.py +++ b/persistent.py @@ -11,8 +11,11 @@ import datetime import enum import functools import logging +import re from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Optional + +from overrides import overrides import file_utils @@ -59,6 +62,113 @@ class Persistent(ABC): pass +class FileBasedPersistent(Persistent): + """A Persistent that uses a file to save/load data and knows the conditions + under which the state should be saved/loaded.""" + + @staticmethod + @abstractmethod + def get_filename() -> str: + """Since this class saves/loads to/from a file, what's its full path?""" + pass + + @staticmethod + @abstractmethod + def should_we_save_data(filename: str) -> bool: + pass + + @staticmethod + @abstractmethod + def should_we_load_data(filename: str) -> bool: + pass + + @abstractmethod + def get_persistent_data(self) -> Any: + pass + + +class PicklingFileBasedPersistent(FileBasedPersistent): + @classmethod + @overrides + def load(cls) -> Optional[Any]: + filename = cls.get_filename() + if cls.should_we_load_data(filename): + logger.debug('Attempting to load state from %s', filename) + + import pickle + + try: + with open(filename, 'rb') as rf: + data = pickle.load(rf) + return cls(data) + + except Exception as e: + raise Exception(f'Failed to load {filename}.') from e + return None + + @overrides + def save(self) -> bool: + filename = self.get_filename() + if self.should_we_save_data(filename): + logger.debug('Trying to save state in %s', filename) + try: + import pickle + + with open(filename, 'wb') as wf: + pickle.dump(self.get_persistent_data(), wf, pickle.HIGHEST_PROTOCOL) + return True + except Exception as e: + raise Exception(f'Failed to save to {filename}.') from e + return False + + +class JsonFileBasedPersistent(FileBasedPersistent): + @classmethod + @overrides + def load(cls) -> Any: + filename = cls.get_filename() + if cls.should_we_load_data(filename): + logger.debug('Trying to load state from %s', filename) + import json + + try: + with open(filename, 'r') as rf: + lines = rf.readlines() + + # This is probably bad... but I like comments + # in config files and JSON doesn't support them. So + # pre-process the buffer to remove comments thus + # allowing people to add them. + buf = '' + for line in lines: + line = re.sub(r'#.*$', '', line) + buf += line + + json_dict = json.loads(buf) + return cls(json_dict) + + except Exception as e: + logger.exception(e) + raise Exception(f'Failed to load {filename}.') from e + return None + + @overrides + def save(self) -> bool: + filename = self.get_filename() + if self.should_we_save_data(filename): + logger.debug('Trying to save state in %s', filename) + try: + import json + + json_blob = json.dumps(self.get_persistent_data()) + with open(filename, 'w') as wf: + wf.writelines(json_blob) + return True + except Exception as e: + raise Exception(f'Failed to save to {filename}.') from e + return False + + def was_file_written_today(filename: str) -> bool: """Convenience wrapper around :meth:`was_file_written_within_n_seconds`. @@ -225,10 +335,6 @@ class persistent_autoloaded_singleton(object): return _load -# TODO: PicklingPersistant? -# TODO: JsonConfigPersistant? - - if __name__ == '__main__': import doctest