Changes towards splitting up the library and (maybe?) publishing on PyPI.
author Scott Gasch <[email protected]>
Tue, 30 Aug 2022 22:36:23 +0000 (15:36 -0700)
committer Scott Gasch <[email protected]>
Tue, 30 Aug 2022 22:36:23 +0000 (15:36 -0700)
NOTICE
collect/shared_dict.py
config.py
tests/shared_dict_test.py

diff --git a/NOTICE b/NOTICE
index 59d5c112579b71de2c745fa42666894b834a7acb..d61aae6e361fbd56dd7f2ae13ba7f90c226ce0ba 100644 (file)
--- a/NOTICE
+++ b/NOTICE
@@ -41,9 +41,10 @@ contains URLs pointing at the source of the forked code.
 
   Scott's modifications include:
     + Adding a unittest (tests/shared_dict_test.py),
+    + Added type hints,
+    + Changes to locking scope,
     + Minor cleanup and style tweaks,
-    + Added sphinx style pydocs,
-    + Added type hints.
+    + Added sphinx style pydocs.
 
   3. The timeout decortator in decorator_utils.py is based on original
   work published in ActiveState code recipes and covered by the PSF
index 3207927ed2f550b6516bce0c1b72fd96d7581ba4..dccae4c933a2c433eb299f05a85b36c1e029bce1 100644 (file)
@@ -33,55 +33,81 @@ from multiprocessing import RLock, shared_memory
 from typing import (
     Any,
     Dict,
-    Generator,
+    Hashable,
     ItemsView,
     Iterator,
     KeysView,
     Optional,
+    Tuple,
     ValuesView,
 )
 
 
 class PickleSerializer:
-    def dumps(self, obj: dict) -> bytes:
+    """A serializer that uses pickling.  Used to read/write bytes in the shared
+    memory region and interpret them as a dict."""
+
+    def dumps(self, obj: Dict[Hashable, Any]) -> bytes:
         try:
             return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
         except pickle.PicklingError as e:
-            raise Exception(e)
+            raise Exception from e
 
-    def loads(self, data: bytes) -> dict:
+    def loads(self, data: bytes) -> Dict[Hashable, Any]:
         try:
             return pickle.loads(data)
         except pickle.UnpicklingError as e:
-            raise Exception(e)
+            raise Exception from e
 
 
-# TODO: protobuf serializer?
+# TODOs: profile the serializers and figure out the fastest one.  Can
+# we use a ChainMap to avoid the constant de/re-serialization of the
+# whole thing?
 
 
 class SharedDict(object):
+    """This class emulates the dict container but uses a
+    Multiprocessing.SharedMemory region to back the dict such that it
+    can be read and written by multiple independent processes at the
+    same time.  Because it constantly de/re-serializes the dict, it is
+    much slower than a normal dict.
+
+    """
+
     NULL_BYTE = b'\x00'
-    MPLOCK = RLock()
+    LOCK = RLock()
 
     def __init__(
         self,
-        name: str,
+        name: Optional[str] = None,
         size_bytes: Optional[int] = None,
     ) -> None:
-        super().__init__()
-        self.name = name
-        self._serializer = PickleSerializer()
+        """
+        Creates or attaches a shared dictionary backed by a SharedMemory buffer.
+        For create semantics, a unique name (string) and a max dictionary size
+        (expressed in bytes) must be provided.  For attach semantics, these are
+        ignored.
+
+        The first process that creates the SharedDict is responsible for
+        (optionally) naming it and deciding the max size (in bytes) that
+        it may be.  It does this via args to the c'tor.
+
+        Subsequent processes may safely omit name and size args.
+
+        """
         assert size_bytes is None or size_bytes > 0
+        self._serializer = PickleSerializer()
         self.shared_memory = self._get_or_create_memory_block(name, size_bytes)
         self._ensure_memory_initialization()
-        self.lock = RLock()
+        self.name = self.shared_memory.name
 
     def get_name(self):
+        """Returns the name of the shared memory buffer backing the dict."""
         return self.name
 
     def _get_or_create_memory_block(
         self,
-        name: str,
+        name: Optional[str] = None,
         size_bytes: Optional[int] = None,
     ) -> shared_memory.SharedMemory:
         try:
@@ -91,59 +117,77 @@ class SharedDict(object):
             return shared_memory.SharedMemory(name=name, create=True, size=size_bytes)
 
     def _ensure_memory_initialization(self):
-        memory_is_empty = bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
-        if memory_is_empty:
-            self.clear()
+        with SharedDict.LOCK:
+            memory_is_empty = bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
+            if memory_is_empty:
+                self.clear()
+
+    def _write_memory(self, db: Dict[Hashable, Any]) -> None:
+        data = self._serializer.dumps(db)
+        with SharedDict.LOCK:
+            try:
+                self.shared_memory.buf[: len(data)] = data
+            except ValueError as e:
+                raise ValueError("exceeds available storage") from e
+
+    def _read_memory(self) -> Dict[Hashable, Any]:
+        with SharedDict.LOCK:
+            return self._serializer.loads(self.shared_memory.buf.tobytes())
+
+    @contextmanager
+    def _modify_dict(self):
+        with SharedDict.LOCK:
+            db = self._read_memory()
+            yield db
+            self._write_memory(db)
 
     def close(self) -> None:
+        """Unmap the shared dict and memory behind it from this
+        process.  Called automatically by __del__."""
         if not hasattr(self, 'shared_memory'):
             return
         self.shared_memory.close()
 
     def cleanup(self) -> None:
+        """Unlink the shared dict and memory behind it.  Only the last process should
+        invoke this.  Not called automatically."""
         if not hasattr(self, 'shared_memory'):
             return
-        with SharedDict.MPLOCK:
+        with SharedDict.LOCK:
             self.shared_memory.unlink()
 
     def clear(self) -> None:
-        self._save_memory({})
-
-    def popitem(self, last: Optional[bool] = None) -> Any:
-        with self._modify_db() as db:
-            return db.popitem()
+        """Clear the dict."""
+        self._write_memory({})
 
-    @contextmanager
-    def _modify_db(self) -> Generator:
-        with SharedDict.MPLOCK:
-            db = self._read_memory()
-            yield db
-            self._save_memory(db)
+    def copy(self) -> Dict[Hashable, Any]:
+        """Returns a shallow copy of the dict."""
+        return self._read_memory()
 
-    def __getitem__(self, key: str) -> Any:
+    def __getitem__(self, key: Hashable) -> Any:
         return self._read_memory()[key]
 
-    def __setitem__(self, key: str, value: Any) -> None:
-        with self._modify_db() as db:
+    def __setitem__(self, key: Hashable, value: Any) -> None:
+        with self._modify_dict() as db:
             db[key] = value
 
     def __len__(self) -> int:
         return len(self._read_memory())
 
-    def __delitem__(self, key: str) -> None:
-        with self._modify_db() as db:
+    def __delitem__(self, key: Hashable) -> None:
+        with self._modify_dict() as db:
             del db[key]
 
-    def __iter__(self) -> Iterator:
+    def __iter__(self) -> Iterator[Hashable]:
         return iter(self._read_memory())
 
-    def __reversed__(self):
+    def __reversed__(self) -> Iterator[Hashable]:
         return reversed(self._read_memory())
 
     def __del__(self) -> None:
         self.close()
 
-    def __contains__(self, key: str) -> bool:
+    def __contains__(self, key: Hashable) -> bool:
         return key in self._read_memory()
 
     def __eq__(self, other: Any) -> bool:
@@ -159,43 +203,38 @@ class SharedDict(object):
         return repr(self._read_memory())
 
     def get(self, key: str, default: Optional[Any] = None) -> Any:
+        """Gets the value associated with key or a default."""
         return self._read_memory().get(key, default)
 
-    def keys(self) -> KeysView[Any]:
+    def keys(self) -> KeysView[Hashable]:
         return self._read_memory().keys()
 
     def values(self) -> ValuesView[Any]:
         return self._read_memory().values()
 
-    def items(self) -> ItemsView:
+    def items(self) -> ItemsView[Hashable, Any]:
         return self._read_memory().items()
 
-    def pop(self, key: str, default: Optional[Any] = None):
-        with self._modify_db() as db:
+    def popitem(self) -> Tuple[Hashable, Any]:
+        """Remove and return the last added item."""
+        with self._modify_dict() as db:
+            return db.popitem()
+
+    def pop(self, key: Hashable, default: Optional[Any] = None) -> Any:
+        """Remove and return the value associated with key or a default."""
+        with self._modify_dict() as db:
             if default is None:
                 return db.pop(key)
             return db.pop(key, default)
 
     def update(self, other=(), /, **kwds):
-        with self._modify_db() as db:
+        with self._modify_dict() as db:
             db.update(other, **kwds)
 
-    def setdefault(self, key: str, default: Optional[Any] = None):
-        with self._modify_db() as db:
+    def setdefault(self, key: Hashable, default: Optional[Any] = None):
+        with self._modify_dict() as db:
             return db.setdefault(key, default)
 
-    def _save_memory(self, db: Dict[str, Any]) -> None:
-        with SharedDict.MPLOCK:
-            data = self._serializer.dumps(db)
-            try:
-                self.shared_memory.buf[: len(data)] = data
-            except ValueError as exc:
-                raise ValueError("exceeds available storage") from exc
-
-    def _read_memory(self) -> Dict[str, Any]:
-        with SharedDict.MPLOCK:
-            return self._serializer.loads(self.shared_memory.buf.tobytes())
-
 
 if __name__ == '__main__':
     import doctest
index a98db575661621e3e7baf5c285cb71337f9d0d52..4d885149529901aedfa884ab602d3e1c3c096f17 100644 (file)
--- a/config.py
+++ b/config.py
@@ -90,8 +90,6 @@ import re
 import sys
 from typing import Any, Dict, List, Optional, Tuple
 
-import scott_secrets
-
 # This module is commonly used by others in here and should avoid
 # taking any unnecessary dependencies back on them.
 
@@ -330,6 +328,46 @@ class Config:
             env = env[1:]
         return var, env, chunks
 
+    @staticmethod
+    def _to_bool(in_str: str) -> bool:
+        """
+        Args:
+            in_str: the string to convert to boolean
+
+        Returns:
+            A boolean equivalent of the original string based on its contents.
+            All conversion is case insensitive.  A positive boolean (True) is
+            returned if the string value is any of the following:
+
+            * "true"
+            * "t"
+            * "1"
+            * "yes"
+            * "y"
+            * "on"
+
+            Otherwise False is returned.
+
+        >>> Config._to_bool('True')
+        True
+
+        >>> Config._to_bool('1')
+        True
+
+        >>> Config._to_bool('yes')
+        True
+
+        >>> Config._to_bool('no')
+        False
+
+        >>> Config._to_bool('huh?')
+        False
+
+        >>> Config._to_bool('on')
+        True
+        """
+        return in_str.lower() in ("true", "1", "yes", "y", "t", "on")
+
     def _augment_sys_argv_from_environment_variables(self):
         """Internal.  Look at the system environment for variables that match
         commandline arg names.  This is done via some munging such that:
@@ -366,9 +404,7 @@ class Config:
                                 self.saved_messages.append(
                                     f'Initialized from environment: {var} = {value}'
                                 )
-                                from string_utils import to_bool
-
-                                if len(chunks) == 1 and to_bool(value):
+                                if len(chunks) == 1 and Config._to_bool(value):
                                     sys.argv.append(var)
                                 elif len(chunks) > 1:
                                     sys.argv.append(var)
@@ -421,8 +457,11 @@ class Config:
             if loadfile[:3] == 'zk:':
                 from kazoo.client import KazooClient
 
+                import scott_secrets
+
                 try:
                     if self.zk is None:
+
                         self.zk = KazooClient(
                             hosts=scott_secrets.ZOOKEEPER_NODES,
                             use_ssl=True,
@@ -545,6 +584,8 @@ class Config:
                     if not self.zk:
                         from kazoo.client import KazooClient
 
+                        import scott_secrets
+
                         self.zk = KazooClient(
                             hosts=scott_secrets.ZOOKEEPER_NODES,
                             use_ssl=True,
index 230bdb989c4c5e74c6d31cecebee3ea468b277c8..68a378800defd2e6c1a134927f6db4dae0adf6f6 100755 (executable)
@@ -4,6 +4,7 @@
 
 """shared_dict unittest."""
 
+import random
 import unittest
 
 import parallelize as p
@@ -14,13 +15,16 @@ from collect.shared_dict import SharedDict
 
 class SharedDictTest(unittest.TestCase):
     @p.parallelize(method=p.Method.PROCESS)
-    def doit(self, n: int, dict_name: str):
-        d = SharedDict(dict_name)
+    def doit(self, n: int, dict_name: str, parent_lock_id: int):
+        assert id(SharedDict.LOCK) == parent_lock_id
+        d = SharedDict(dict_name, None)
         try:
             msg = f'Hello from shard {n}'
-            d[n] = msg
-            self.assertTrue(n in d)
-            self.assertEqual(msg, d[n])
+            for x in range(0, 1000):
+                d[n] = msg
+                self.assertTrue(n in d)
+                self.assertEqual(msg, d[n])
+                y = d.get(random.randrange(0, 99), None)
             return n
         finally:
             d.close()
@@ -32,23 +36,25 @@ class SharedDictTest(unittest.TestCase):
             self.assertEqual(dict_name, d.get_name())
             results = []
             for n in range(100):
-                f = self.doit(n, d.get_name())
+                f = self.doit(n, d.get_name(), id(SharedDict.LOCK))
                 results.append(f)
             smart_future.wait_all(results)
             for f in results:
                 self.assertTrue(f.wrapped_future.done())
             for k in d:
                 self.assertEqual(d[k], f'Hello from shard {k}')
+            assert len(d) == 100
         finally:
             d.close()
             d.cleanup()
 
     @p.parallelize(method=p.Method.PROCESS)
-    def add_one(self, name: str):
+    def add_one(self, name: str, expected_lock_id: int):
         d = SharedDict(name)
+        self.assertEqual(id(SharedDict.LOCK), expected_lock_id)
         try:
             for x in range(1000):
-                with SharedDict.MPLOCK:
+                with SharedDict.LOCK:
                     d["sum"] += 1
         finally:
             d.close()
@@ -60,7 +66,7 @@ class SharedDictTest(unittest.TestCase):
             d["sum"] = 0
             results = []
             for n in range(10):
-                f = self.add_one(d.get_name())
+                f = self.add_one(d.get_name(), id(SharedDict.LOCK))
                 results.append(f)
             smart_future.wait_all(results)
             self.assertEqual(10000, d["sum"])