Changes towards splitting up the library and (maybe?) publishing on PyPi.
[python_utils.git] / collect / shared_dict.py
index ac390bc600e769a94c2205440883257a665917bb..dccae4c933a2c433eb299f05a85b36c1e029bce1 100644 (file)
@@ -4,7 +4,7 @@
 The MIT License (MIT)
 
 Copyright (c) 2020 LuizaLabs
-Additions Copyright (c) 2022 Scott Gasch
+Additions/Modifications Copyright (c) 2022 Scott Gasch
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
@@ -29,62 +29,85 @@ This class is based on https://github.com/luizalabs/shared-memory-dict
 
 import pickle
 from contextlib import contextmanager
-from functools import wraps
 from multiprocessing import RLock, shared_memory
 from typing import (
     Any,
     Dict,
-    Generator,
+    Hashable,
     ItemsView,
     Iterator,
     KeysView,
     Optional,
+    Tuple,
     ValuesView,
 )
 
-from decorator_utils import synchronized
-
 
 class PickleSerializer:
-    def dumps(self, obj: dict) -> bytes:
+    """A serializer that uses pickling.  Used to read/write bytes in the shared
+    memory region and interpret them as a dict."""
+
+    def dumps(self, obj: Dict[Hashable, Any]) -> bytes:
         try:
             return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
         except pickle.PicklingError as e:
-            raise Exception(e)
+            raise Exception from e
 
-    def loads(self, data: bytes) -> dict:
+    def loads(self, data: bytes) -> Dict[Hashable, Any]:
         try:
             return pickle.loads(data)
         except pickle.UnpicklingError as e:
-            raise Exception(e)
+            raise Exception from e
 
 
-# TODO: protobuf serializer?
+# TODOs: profile the serializers and figure out the fastest one.  Can
+# we use a ChainMap to avoid the constant de/re-serialization of the
+# whole thing?
 
 
 class SharedDict(object):
+    """This class emulates the dict container but uses a
+    Multiprocessing.SharedMemory region to back the dict such that it
+    can be read and written by multiple independent processes at the
+    same time.  Because it constantly de/re-serializes the dict, it is
+    much slower than a normal dict.
+
+    """
+
     NULL_BYTE = b'\x00'
-    MPLOCK = RLock()
+    LOCK = RLock()
 
     def __init__(
         self,
-        name: str,
+        name: Optional[str] = None,
         size_bytes: Optional[int] = None,
     ) -> None:
-        super().__init__()
-        self.name = name
-        self._serializer = PickleSerializer()
+        """
+        Creates or attaches a shared dictionary back by a SharedMemory buffer.
+        For create semantics, a unique name (string) and a max dictionary size
+        (expressed in bytes) must be provided.  For attach semantics, these are
+        ignored.
+
+        The first process that creates the SharedDict is responsible for
+        (optionally) naming it and deciding the max size (in bytes) that
+        it may be.  It does this via args to the c'tor.
+
+        Subsequent processes may safely omit name and size args.
+
+        """
         assert size_bytes is None or size_bytes > 0
+        self._serializer = PickleSerializer()
         self.shared_memory = self._get_or_create_memory_block(name, size_bytes)
         self._ensure_memory_initialization()
-        self.lock = RLock()
+        self.name = self.shared_memory.name
 
     def get_name(self):
+        """Returns the name of the shared memory buffer backing the dict."""
         return self.name
 
     def _get_or_create_memory_block(
         self,
-        name: str,
+        name: Optional[str] = None,
         size_bytes: Optional[int] = None,
     ) -> shared_memory.SharedMemory:
         try:
@@ -94,60 +117,77 @@ class SharedDict(object):
             return shared_memory.SharedMemory(name=name, create=True, size=size_bytes)
 
     def _ensure_memory_initialization(self):
-        memory_is_empty = (
-            bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
-        )
-        if memory_is_empty:
-            self.clear()
+        with SharedDict.LOCK:
+            memory_is_empty = bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
+            if memory_is_empty:
+                self.clear()
+
+    def _write_memory(self, db: Dict[Hashable, Any]) -> None:
+        data = self._serializer.dumps(db)
+        with SharedDict.LOCK:
+            try:
+                self.shared_memory.buf[: len(data)] = data
+            except ValueError as e:
+                raise ValueError("exceeds available storage") from e
+
+    def _read_memory(self) -> Dict[Hashable, Any]:
+        with SharedDict.LOCK:
+            return self._serializer.loads(self.shared_memory.buf.tobytes())
+
+    @contextmanager
+    def _modify_dict(self):
+        with SharedDict.LOCK:
+            db = self._read_memory()
+            yield db
+            self._write_memory(db)
 
     def close(self) -> None:
+        """Unmap the shared dict and memory behind it from this
+        process.  Called by automatically __del__"""
         if not hasattr(self, 'shared_memory'):
             return
         self.shared_memory.close()
 
     def cleanup(self) -> None:
+        """Unlink the shared dict and memory behind it.  Only the last process should
+        invoke this.  Not called automatically."""
         if not hasattr(self, 'shared_memory'):
             return
-        self.shared_memory.unlink()
+        with SharedDict.LOCK:
+            self.shared_memory.unlink()
 
     def clear(self) -> None:
-        self._save_memory({})
-
-    def popitem(self, last: Optional[bool] = None) -> Any:
-        with self._modify_db() as db:
-            return db.popitem()
+        """Clear the dict."""
+        self._write_memory({})
 
-    @contextmanager
-    def _modify_db(self) -> Generator:
-        with SharedDict.MPLOCK:
-            db = self._read_memory()
-            yield db
-            self._save_memory(db)
+    def copy(self) -> Dict[Hashable, Any]:
+        """Returns a shallow copy of the dict."""
+        return self._read_memory()
 
-    def __getitem__(self, key: str) -> Any:
+    def __getitem__(self, key: Hashable) -> Any:
         return self._read_memory()[key]
 
-    def __setitem__(self, key: str, value: Any) -> None:
-        with self._modify_db() as db:
+    def __setitem__(self, key: Hashable, value: Any) -> None:
+        with self._modify_dict() as db:
             db[key] = value
 
     def __len__(self) -> int:
         return len(self._read_memory())
 
-    def __delitem__(self, key: str) -> None:
-        with self._modify_db() as db:
+    def __delitem__(self, key: Hashable) -> None:
+        with self._modify_dict() as db:
             del db[key]
 
-    def __iter__(self) -> Iterator:
+    def __iter__(self) -> Iterator[Hashable]:
         return iter(self._read_memory())
 
-    def __reversed__(self):
+    def __reversed__(self) -> Iterator[Hashable]:
         return reversed(self._read_memory())
 
     def __del__(self) -> None:
         self.close()
 
-    def __contains__(self, key: str) -> bool:
+    def __contains__(self, key: Hashable) -> bool:
         return key in self._read_memory()
 
     def __eq__(self, other: Any) -> bool:
@@ -163,43 +203,38 @@ class SharedDict(object):
         return repr(self._read_memory())
 
     def get(self, key: str, default: Optional[Any] = None) -> Any:
+        """Gets the value associated with key or a default."""
         return self._read_memory().get(key, default)
 
-    def keys(self) -> KeysView[Any]:
+    def keys(self) -> KeysView[Hashable]:
         return self._read_memory().keys()
 
     def values(self) -> ValuesView[Any]:
         return self._read_memory().values()
 
-    def items(self) -> ItemsView:
+    def items(self) -> ItemsView[Hashable, Any]:
         return self._read_memory().items()
 
-    def pop(self, key: str, default: Optional[Any] = None):
-        with self._modify_db() as db:
+    def popitem(self) -> Tuple[Hashable, Any]:
+        """Remove and return the last added item."""
+        with self._modify_dict() as db:
+            return db.popitem()
+
+    def pop(self, key: Hashable, default: Optional[Any] = None) -> Any:
+        """Remove and return the value associated with key or a default"""
+        with self._modify_dict() as db:
             if default is None:
                 return db.pop(key)
             return db.pop(key, default)
 
     def update(self, other=(), /, **kwds):
-        with self._modify_db() as db:
+        with self._modify_dict() as db:
             db.update(other, **kwds)
 
-    def setdefault(self, key: str, default: Optional[Any] = None):
-        with self._modify_db() as db:
+    def setdefault(self, key: Hashable, default: Optional[Any] = None):
+        with self._modify_dict() as db:
             return db.setdefault(key, default)
 
-    def _save_memory(self, db: Dict[str, Any]) -> None:
-        with SharedDict.MPLOCK:
-            data = self._serializer.dumps(db)
-            try:
-                self.shared_memory.buf[: len(data)] = data
-            except ValueError as exc:
-                raise ValueError("exceeds available storage") from exc
-
-    def _read_memory(self) -> Dict[str, Any]:
-        with SharedDict.MPLOCK:
-            return self._serializer.loads(self.shared_memory.buf.tobytes())
-
 
 if __name__ == '__main__':
     import doctest