Add dataclass_utils for some simple dataclass wrappers and annotation.
author Scott Gasch <[email protected]>
Wed, 26 Oct 2022 03:58:53 +0000 (20:58 -0700)
committer Scott Gasch <[email protected]>
Wed, 26 Oct 2022 03:58:53 +0000 (20:58 -0700)
src/pyutils/dataclass_utils.py [new file with mode: 0644]
src/pyutils/dict_utils.py
src/pyutils/parallelize/executors.py

diff --git a/src/pyutils/dataclass_utils.py b/src/pyutils/dataclass_utils.py
new file mode 100644 (file)
index 0000000..22d483e
--- /dev/null
@@ -0,0 +1,90 @@
+#!/usr/bin/env python3
+
+"""Utilities for dealing with Dataclasses.  A non-official type hint and some
+friendly wrappers around conversion to/from Dicts."""
+
+import dataclasses
+from typing import Any, Dict, Protocol, Type
+
+
+class Dataclass(Protocol):
+    """Dataclass isn't really a first-class type and therefore there is no official
+    type hint for Dataclasses in Python (yet).  If you need one, here's a suitable
+    stand-in.  Example usage::
+
+        def f(d: Dataclass) -> Any:
+            pass
+
+        def g(d: Dict[str, Any]) -> Dataclass:
+            pass
+    """
+
+    __dataclass_fields__: Dict
+
+
+def dataclass_from_dict(dataclass: Type[Dataclass], d: Dict[str, Any]) -> Dataclass:
+    """Given a Dataclass type and a dict, return a populated instance.
+
+    Args:
+        dataclass: the Class type to return an instance of
+        d: the dict to be used to populate the new instance
+
+    Returns:
+        A constructed and populated dataclass instance.
+
+    >>> from dataclasses import dataclass
+    >>> from datetime import date
+
+    >>> @dataclass
+    ... class Record:
+    ...     name: str
+    ...     phone: str
+    ...     address: str
+    ...     age: int
+    ...     member_since: date
+    ...
+
+    >>> d = {
+    ...         'name': 'John Smith',
+    ...         'phone': '555-1234',
+    ...         'address': '994 Main St.',
+    ...         'age': 26,
+    ...         'member_since': date(2006, 5, 14),
+    ...     }
+
+    >>> dataclass_from_dict(Record, d)
+    Record(name='John Smith', phone='555-1234', address='994 Main St.', age=26, member_since=datetime.date(2006, 5, 14))
+    """
+    fields = {f.name for f in dataclasses.fields(dataclass) if f.init}
+    filtered_args = {k: v for k, v in d.items() if k in fields}
+    return dataclass(**filtered_args)
+
+
+def dataclass_to_dict(dataclass: Dataclass) -> Dict[str, Any]:
+    """
+    Returns:
+        A dict-representation of a valid dataclass.
+
+    >>> from dataclasses import dataclass
+    >>> from datetime import date
+
+    >>> @dataclass
+    ... class Record:
+    ...     name: str
+    ...     phone: str
+    ...     address: str
+    ...     age: int
+    ...     member_since: date
+    ...
+    >>> r = Record(name='Jane Doe', phone='555-1232', address='998 Main St.', age=23, member_since=date(2008, 3, 1))
+    >>> dataclass_to_dict(r)
+    {'name': 'Jane Doe', 'phone': '555-1232', 'address': '998 Main St.', 'age': 23, 'member_since': datetime.date(2008, 3, 1)}
+    """
+    assert dataclasses.is_dataclass(dataclass)
+    return dataclasses.asdict(dataclass)
+
+
+if __name__ == '__main__':
+    import doctest
+
+    doctest.testmod()
index e5fbb48a38800a938df3aea16bd99ada888ece72..c269723a8d3684348f8ec07434fb6588d5b40bf7 100644 (file)
@@ -7,6 +7,8 @@
 from itertools import islice
 from typing import Any, Callable, Dict, Iterator, List, Tuple
 
+from pyutils import dataclass_utils
+
 
 def init_or_inc(
     d: Dict[Any, Any],
@@ -361,6 +363,11 @@ def dict_to_key_value_lists(d: Dict[Any, Any]) -> Tuple[List[Any], List[Any]]:
     return r
 
 
+dict_to_dataclass = dataclass_utils.dataclass_from_dict
+
+dict_from_dataclass = dataclass_utils.dataclass_to_dict
+
+
 if __name__ == '__main__':
     import doctest
 
index 23fd6eb96c3a9a3198691c705c455902d982fce8..a2877fa9eb19d958cf7263828a1fc0598e542d2e 100644 (file)
@@ -49,18 +49,26 @@ import time
 import warnings
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from dataclasses import dataclass, fields
+from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional, Set
 
 import cloudpickle  # type: ignore
 from overrides import overrides
 
 import pyutils.typez.histogram as hist
-from pyutils import argparse_utils, config, math_utils, persistent, string_utils
+from pyutils import (
+    argparse_utils,
+    config,
+    dataclass_utils,
+    math_utils,
+    persistent,
+    string_utils,
+)
 from pyutils.ansi import bg, fg, reset, underline
 from pyutils.decorator_utils import singleton
 from pyutils.exec_utils import cmd_exitcode, cmd_in_background, run_silently
 from pyutils.parallelize.thread_utils import background_thread
+from pyutils.typez import type_utils
 
 logger = logging.getLogger(__name__)
 
@@ -1527,16 +1535,10 @@ class ConfigRemoteWorkerPoolProvider(
         self.remote_worker_pool = []
         for record in json_remote_worker_pool['remote_worker_records']:
             self.remote_worker_pool.append(
-                self.dataclassFromDict(RemoteWorkerRecord, record)
+                dataclass_utils.dataclass_from_dict(RemoteWorkerRecord, record)
             )
         assert len(self.remote_worker_pool) > 0
 
-    @staticmethod
-    def dataclassFromDict(clsName, argDict: Dict[str, Any]) -> Any:
-        fieldSet = {f.name for f in fields(clsName) if f.init}
-        filteredArgDict = {k: v for k, v in argDict.items() if k in fieldSet}
-        return clsName(**filteredArgDict)
-
     @overrides
     def get_remote_workers(self) -> List[RemoteWorkerRecord]:
         return self.remote_worker_pool
@@ -1548,7 +1550,7 @@ class ConfigRemoteWorkerPoolProvider(
     @staticmethod
     @overrides
     def get_filename() -> str:
-        return config.config['remote_worker_records_file']
+        return type_utils.unwrap_optional(config.config['remote_worker_records_file'])
 
     @staticmethod
     @overrides