From: Scott Gasch Date: Wed, 26 Oct 2022 03:58:53 +0000 (-0700) Subject: Add dataclass_utils for some simple dataclass wrappers and annotation. X-Git-Url: https://wannabe.guru.org/gitweb/?a=commitdiff_plain;h=a789441d4abd1a97271b44d6f8d09925d6777dda;p=pyutils.git Add dataclass_utils for some simple dataclass wrappers and annotation. --- diff --git a/src/pyutils/dataclass_utils.py b/src/pyutils/dataclass_utils.py new file mode 100644 index 0000000..22d483e --- /dev/null +++ b/src/pyutils/dataclass_utils.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 + +"""Utilities for dealing with Dataclasses. A non-official type hint and some +friendly wrappers around conversion to/from Dicts.""" + +import dataclasses +from typing import Any, Dict, Protocol, Type + + +class Dataclass(Protocol): + """Dataclass isn't really a first class type and therefore there is no official + type hint for Dataclasses in Python (yet). If you need one, here's a suitable + stand in. Example usage:: + + def f(d: Dataclass) -> Any: + pass + + def g(d: Dict[str, Any]) -> Dataclass: + pass + """ + + __dataclass_fields__: Dict + + +def dataclass_from_dict(dataclass: Type[Dataclass], d: Dict[str, Any]) -> Dataclass: + """Given a Dataclass type and a dict, return a populated instance. + + Args: + dataclass: the Class type to return an instance of + d: the dict to be used to populate the new instance + + Returns: + A constructed and populated dataclass instance. + + >>> from dataclasses import dataclass + >>> from datetime import date + + >>> @dataclass + ... class Record: + ... name: str + ... phone: str + ... address: str + ... age: int + ... member_since: date + ... + + >>> d = { + ... 'name': 'John Smith', + ... 'phone': '555-1234', + ... 'address': '994 Main St.', + ... 'age': 26, + ... 'member_since': date(2006, 5, 14), + ... 
} + + >>> dataclass_from_dict(Record, d) + Record(name='John Smith', phone='555-1234', address='994 Main St.', age=26, member_since=datetime.date(2006, 5, 14)) + """ + fields = {f.name for f in dataclasses.fields(dataclass) if f.init} + filtered_args = {k: v for k, v in d.items() if k in fields} + return dataclass(**filtered_args) + + +def dataclass_to_dict(dataclass: Dataclass) -> Dict[str, Any]: + """ + Returns: + A dict-representation of a valid dataclass. + + >>> from dataclasses import dataclass + >>> from datetime import date + + >>> @dataclass + ... class Record: + ... name: str + ... phone: str + ... address: str + ... age: int + ... member_since: date + ... + >>> r = Record(name='Jane Doe', phone='555-1232', address='998 Main St.', age=23, member_since=date(2008, 3, 1)) + >>> dataclass_to_dict(r) + {'name': 'Jane Doe', 'phone': '555-1232', 'address': '998 Main St.', 'age': 23, 'member_since': datetime.date(2008, 3, 1)} + """ + assert dataclasses.is_dataclass(dataclass) + return dataclasses.asdict(dataclass) + + +if __name__ == '__main__': + import doctest + + doctest.testmod() diff --git a/src/pyutils/dict_utils.py b/src/pyutils/dict_utils.py index e5fbb48..c269723 100644 --- a/src/pyutils/dict_utils.py +++ b/src/pyutils/dict_utils.py @@ -7,6 +7,8 @@ from itertools import islice from typing import Any, Callable, Dict, Iterator, List, Tuple +from pyutils import dataclass_utils + def init_or_inc( d: Dict[Any, Any], @@ -361,6 +363,11 @@ def dict_to_key_value_lists(d: Dict[Any, Any]) -> Tuple[List[Any], List[Any]]: return r +dict_to_dataclass = dataclass_utils.dataclass_from_dict + +dict_from_dataclass = dataclass_utils.dataclass_to_dict + + if __name__ == '__main__': import doctest diff --git a/src/pyutils/parallelize/executors.py b/src/pyutils/parallelize/executors.py index 23fd6eb..a2877fa 100644 --- a/src/pyutils/parallelize/executors.py +++ b/src/pyutils/parallelize/executors.py @@ -49,18 +49,26 @@ import time import warnings from abc import ABC, 
abstractmethod from collections import defaultdict -from dataclasses import dataclass, fields +from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Set import cloudpickle # type: ignore from overrides import overrides import pyutils.typez.histogram as hist -from pyutils import argparse_utils, config, math_utils, persistent, string_utils +from pyutils import ( + argparse_utils, + config, + dataclass_utils, + math_utils, + persistent, + string_utils, +) from pyutils.ansi import bg, fg, reset, underline from pyutils.decorator_utils import singleton from pyutils.exec_utils import cmd_exitcode, cmd_in_background, run_silently from pyutils.parallelize.thread_utils import background_thread +from pyutils.typez import type_utils logger = logging.getLogger(__name__) @@ -1527,16 +1535,10 @@ class ConfigRemoteWorkerPoolProvider( self.remote_worker_pool = [] for record in json_remote_worker_pool['remote_worker_records']: self.remote_worker_pool.append( - self.dataclassFromDict(RemoteWorkerRecord, record) + dataclass_utils.dataclass_from_dict(RemoteWorkerRecord, record) ) assert len(self.remote_worker_pool) > 0 - @staticmethod - def dataclassFromDict(clsName, argDict: Dict[str, Any]) -> Any: - fieldSet = {f.name for f in fields(clsName) if f.init} - filteredArgDict = {k: v for k, v in argDict.items() if k in fieldSet} - return clsName(**filteredArgDict) - @overrides def get_remote_workers(self) -> List[RemoteWorkerRecord]: return self.remote_worker_pool @@ -1548,7 +1550,7 @@ class ConfigRemoteWorkerPoolProvider( @staticmethod @overrides def get_filename() -> str: - return config.config['remote_worker_records_file'] + return type_utils.unwrap_optional(config.config['remote_worker_records_file']) @staticmethod @overrides