--- /dev/null
+#!/usr/bin/env python3
+
+import pickle
+from contextlib import contextmanager
+from functools import wraps
+from multiprocessing import shared_memory, Lock
+from typing import (
+ Any,
+ Dict,
+ Generator,
+ KeysView,
+ ItemsView,
+ Iterator,
+ Optional,
+ ValuesView,
+)
+
+from decorator_utils import synchronized
+
+
+class PickleSerializer:
+ def dumps(self, obj: dict) -> bytes:
+ try:
+ return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
+ except pickle.PicklingError as e:
+ raise Exception(e)
+
+ def loads(self, data: bytes) -> dict:
+ try:
+ return pickle.loads(data)
+ except pickle.UnpicklingError as e:
+ raise Exception(e)
+
+
+# TODO: protobuf serializer?
+
+
+class SharedDict(object):
+ NULL_BYTE = b'\x00'
+ MPLOCK = Lock()
+
+ def __init__(
+ self,
+ name: str,
+ size: int,
+ ) -> None:
+ super().__init__()
+ self._serializer = PickleSerializer()
+ self.shared_memory = self._get_or_create_memory_block(name, size)
+ self._ensure_memory_initialization()
+
+ def _get_or_create_memory_block(
+ self, name: str, size: int
+ ) -> shared_memory.SharedMemory:
+ try:
+ return shared_memory.SharedMemory(name=name)
+ except FileNotFoundError:
+ return shared_memory.SharedMemory(name=name, create=True, size=size)
+
+ def _ensure_memory_initialization(self):
+ memory_is_empty = (
+ bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
+ )
+ if memory_is_empty:
+ self.clear()
+
+ def close(self) -> None:
+ if not hasattr(self, 'shared_memory'):
+ return
+ self.shared_memory.close()
+
+ def cleanup(self) -> None:
+ if not hasattr(self, 'shared_memory'):
+ return
+ self.shared_memory.unlink()
+
+ @synchronized(MPLOCK)
+ def clear(self) -> None:
+ self._save_memory({})
+
+ def popitem(self, last: Optional[bool] = None) -> Any:
+ with self._modify_db() as db:
+ return db.popitem()
+
+ @synchronized(MPLOCK)
+ @contextmanager
+ def _modify_db(self) -> Generator:
+ db = self._read_memory()
+ yield db
+ self._save_memory(db)
+
+ def __getitem__(self, key: str) -> Any:
+ return self._read_memory()[key]
+
+ def __setitem__(self, key: str, value: Any) -> None:
+ with self._modify_db() as db:
+ db[key] = value
+
+ def __len__(self) -> int:
+ return len(self._read_memory())
+
+ def __delitem__(self, key: str) -> None:
+ with self._modify_db() as db:
+ del db[key]
+
+ def __iter__(self) -> Iterator:
+ return iter(self._read_memory())
+
+ def __reversed__(self):
+ return reversed(self._read_memory())
+
+ def __del__(self) -> None:
+ self.close()
+
+ def __contains__(self, key: str) -> bool:
+ return key in self._read_memory()
+
+ def __eq__(self, other: Any) -> bool:
+ return self._read_memory() == other
+
+ def __ne__(self, other: Any) -> bool:
+ return self._read_memory() != other
+
+ def __str__(self) -> str:
+ return str(self._read_memory())
+
+ def __repr__(self) -> str:
+ return repr(self._read_memory())
+
+ def get(self, key: str, default: Optional[Any] = None) -> Any:
+ return self._read_memory().get(key, default)
+
+ def keys(self) -> KeysView[Any]:
+ return self._read_memory().keys()
+
+ def values(self) -> ValuesView[Any]:
+ return self._read_memory().values()
+
+ def items(self) -> ItemsView:
+ return self._read_memory().items()
+
+ def pop(self, key: str, default: Optional[Any] = None):
+ with self._modify_db() as db:
+ if default is None:
+ return db.pop(key)
+ return db.pop(key, default)
+
+ def update(self, other=(), /, **kwds):
+ with self._modify_db() as db:
+ db.update(other, **kwds)
+
+ def setdefault(self, key: str, default: Optional[Any] = None):
+ with self._modify_db() as db:
+ return db.setdefault(key, default)
+
+ def _save_memory(self, db: Dict[str, Any]) -> None:
+ data = self._serializer.dumps(db)
+ try:
+ self.shared_memory.buf[: len(data)] = data
+ except ValueError as exc:
+ raise ValueError("exceeds available storage") from exc
+
+ def _read_memory(self) -> Dict[str, Any]:
+ return self._serializer.loads(self.shared_memory.buf.tobytes())
+
+
+if __name__ == '__main__':
+ import doctest
+
+ doctest.testmod()