Add an MP (multiprocessing) shared dict.
author Scott <[email protected]>
Mon, 31 Jan 2022 01:25:59 +0000 (17:25 -0800)
committer Scott <[email protected]>
Mon, 31 Jan 2022 01:25:59 +0000 (17:25 -0800)
collect/shared_dict.py [new file with mode: 0644]

diff --git a/collect/shared_dict.py b/collect/shared_dict.py
new file mode 100644 (file)
index 0000000..1c70c3d
--- /dev/null
@@ -0,0 +1,170 @@
+#!/usr/bin/env python3
+
+import pickle
+from contextlib import contextmanager
+from functools import wraps
+from multiprocessing import shared_memory, Lock
+from typing import (
+    Any,
+    Dict,
+    Generator,
+    KeysView,
+    ItemsView,
+    Iterator,
+    Optional,
+    ValuesView,
+)
+
+from decorator_utils import synchronized
+
+
+class PickleSerializer:
+    """Serialize dicts to and from bytes using :mod:`pickle`."""
+
+    def dumps(self, obj: dict) -> bytes:
+        """Return *obj* pickled with the highest available protocol.
+
+        Raises:
+            Exception: wrapping a ``pickle.PicklingError``; the original
+                error is preserved as ``__cause__``.
+        """
+        try:
+            return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
+        except pickle.PicklingError as e:
+            # Chain explicitly so the pickling failure stays attached to
+            # the generic exception callers catch.
+            raise Exception(e) from e
+
+    def loads(self, data: bytes) -> dict:
+        """Return the object unpickled from *data*.
+
+        NOTE: unpickling can execute arbitrary code; only feed this data
+        that was produced by a trusted writer.
+
+        Raises:
+            Exception: wrapping a ``pickle.UnpicklingError``; the original
+                error is preserved as ``__cause__``.
+        """
+        try:
+            return pickle.loads(data)
+        except pickle.UnpicklingError as e:
+            raise Exception(e) from e
+
+
+# TODO: protobuf serializer?
+
+
+class SharedDict(object):
+    """Dict-like wrapper around a pickled dict kept in named shared memory.
+
+    Every operation unpickles the whole dict from the shared memory block.
+    Mutating operations do this while holding ``MPLOCK`` and pickle the
+    modified dict back before releasing it.  Read-only operations do not
+    take the lock, and the objects they return (including the views from
+    ``keys()``, ``values()`` and ``items()``) describe a private snapshot,
+    not live shared state.
+    """
+
+    NULL_BYTE = b'\x00'
+    # NOTE(review): this lock is created once, at class-definition time;
+    # presumably it only coordinates processes that inherit it (e.g. forked
+    # children) -- confirm against how worker processes are started.
+    MPLOCK = Lock()
+    # Private sentinel so that None can be passed as a real default to pop().
+    _MISSING = object()
+
+    def __init__(
+        self,
+        name: str,
+        size: int,
+    ) -> None:
+        """Attach to the shared memory block *name*, creating it if needed.
+
+        Args:
+            name: name of the shared memory block.
+            size: size in bytes, used only when the block is created.
+        """
+        super().__init__()
+        self._serializer = PickleSerializer()
+        self.shared_memory = self._get_or_create_memory_block(name, size)
+        self._ensure_memory_initialization()
+
+    def _get_or_create_memory_block(
+        self, name: str, size: int
+    ) -> shared_memory.SharedMemory:
+        """Return the existing block called *name*, or create it."""
+        try:
+            return shared_memory.SharedMemory(name=name)
+        except FileNotFoundError:
+            return shared_memory.SharedMemory(name=name, create=True, size=size)
+
+    def _ensure_memory_initialization(self) -> None:
+        """Store an empty dict if the block does not hold pickled data yet.
+
+        The block is treated as uninitialized when its first byte is a
+        null byte (or it has no bytes at all).
+        """
+        # Only the first byte matters; avoid copying the whole segment.
+        first_byte = bytes(self.shared_memory.buf[:1])
+        if first_byte in (b'', SharedDict.NULL_BYTE):
+            self.clear()
+
+    def close(self) -> None:
+        """Close this process's handle on the shared memory block."""
+        # __init__ may have failed before the attribute was set.
+        if not hasattr(self, 'shared_memory'):
+            return
+        self.shared_memory.close()
+
+    def cleanup(self) -> None:
+        """Unlink the underlying shared memory block."""
+        if not hasattr(self, 'shared_memory'):
+            return
+        self.shared_memory.unlink()
+
+    def clear(self) -> None:
+        """Replace the stored dict with an empty one, under MPLOCK."""
+        with SharedDict.MPLOCK:
+            self._save_memory({})
+
+    def popitem(self, last: Optional[bool] = None) -> Any:
+        """Remove and return a (key, value) pair, as ``dict.popitem`` does.
+
+        *last* is accepted but ignored.
+        """
+        with self._modify_db() as db:
+            return db.popitem()
+
+    @contextmanager
+    def _modify_db(self) -> Generator[Dict[str, Any], None, None]:
+        """Yield the stored dict for modification and write it back on exit.
+
+        MPLOCK is held for the whole read/modify/write cycle, so the lock
+        must be taken inside the generator: decorating the context-manager
+        factory would only guard creation of the context manager object.
+        If the ``with`` body raises, nothing is written back.
+        """
+        with SharedDict.MPLOCK:
+            db = self._read_memory()
+            yield db
+            self._save_memory(db)
+
+    def __getitem__(self, key: str) -> Any:
+        return self._read_memory()[key]
+
+    def __setitem__(self, key: str, value: Any) -> None:
+        with self._modify_db() as db:
+            db[key] = value
+
+    def __len__(self) -> int:
+        return len(self._read_memory())
+
+    def __delitem__(self, key: str) -> None:
+        with self._modify_db() as db:
+            del db[key]
+
+    def __iter__(self) -> Iterator:
+        return iter(self._read_memory())
+
+    def __reversed__(self):
+        return reversed(self._read_memory())
+
+    def __del__(self) -> None:
+        self.close()
+
+    def __contains__(self, key: str) -> bool:
+        return key in self._read_memory()
+
+    def __eq__(self, other: Any) -> bool:
+        return self._read_memory() == other
+
+    def __ne__(self, other: Any) -> bool:
+        return self._read_memory() != other
+
+    def __str__(self) -> str:
+        return str(self._read_memory())
+
+    def __repr__(self) -> str:
+        return repr(self._read_memory())
+
+    def get(self, key: str, default: Optional[Any] = None) -> Any:
+        """Return the value for *key*, or *default* if it is absent."""
+        return self._read_memory().get(key, default)
+
+    def keys(self) -> KeysView[Any]:
+        """Return the keys of a snapshot of the stored dict."""
+        return self._read_memory().keys()
+
+    def values(self) -> ValuesView[Any]:
+        """Return the values of a snapshot of the stored dict."""
+        return self._read_memory().values()
+
+    def items(self) -> ItemsView:
+        """Return the items of a snapshot of the stored dict."""
+        return self._read_memory().items()
+
+    def pop(self, key: str, default: Any = _MISSING):
+        """Remove *key* and return its value.
+
+        If *key* is absent, return *default* when one was given (``None``
+        included); otherwise raise ``KeyError``.
+        """
+        with self._modify_db() as db:
+            if default is SharedDict._MISSING:
+                return db.pop(key)
+            return db.pop(key, default)
+
+    def update(self, other=(), /, **kwds):
+        """Update the stored dict as ``dict.update`` does."""
+        with self._modify_db() as db:
+            db.update(other, **kwds)
+
+    def setdefault(self, key: str, default: Optional[Any] = None):
+        """Return the value for *key*, storing *default* first if absent."""
+        with self._modify_db() as db:
+            return db.setdefault(key, default)
+
+    def _save_memory(self, db: Dict[str, Any]) -> None:
+        """Serialize *db* and write it at the start of the block.
+
+        Raises:
+            ValueError: if the serialized dict does not fit in the block.
+        """
+        data = self._serializer.dumps(db)
+        try:
+            # Bytes past len(data) are left untouched.
+            self.shared_memory.buf[: len(data)] = data
+        except ValueError as exc:
+            raise ValueError("exceeds available storage") from exc
+
+    def _read_memory(self) -> Dict[str, Any]:
+        """Deserialize and return the dict currently stored in the block."""
+        return self._serializer.loads(self.shared_memory.buf.tobytes())
+
+
+if __name__ == '__main__':
+    # When executed as a script, run the doctests found in this module.
+    from doctest import testmod
+
+    testmod()