#!/usr/bin/env python3

"""Dictionary-like storage backed by a named shared-memory block.

The complete mapping is pickled into the block on every write and
unpickled from it on every read.
"""

import pickle
from contextlib import contextmanager
from functools import wraps
from multiprocessing import shared_memory, Lock
from typing import (
    Any,
    Dict,
    Generator,
    KeysView,
    ItemsView,
    Iterator,
    Optional,
    ValuesView,
)

from decorator_utils import synchronized

# Sentinel meaning "the caller supplied no default"; lets None be a real
# default value in SharedDict.pop().
_MISSING = object()


class PickleSerializer:
    """Serialize dictionaries to and from pickle byte strings."""

    def dumps(self, obj: dict) -> bytes:
        """Return *obj* pickled with the highest available protocol.

        Raises:
            Exception: wrapping the pickle.PicklingError when pickling fails.
        """
        try:
            return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
        except pickle.PicklingError as e:
            raise Exception(e) from e

    def loads(self, data: bytes) -> dict:
        """Return the object unpickled from *data*.

        Raises:
            Exception: wrapping the pickle.UnpicklingError when the data
                cannot be unpickled.
        """
        # NOTE(security): unpickling can execute arbitrary code carried in
        # the payload.  Only use this with memory blocks whose writers are
        # trusted.
        try:
            return pickle.loads(data)
        except pickle.UnpicklingError as e:
            raise Exception(e) from e


# TODO: protobuf serializer?


class SharedDict:
    """A dict-like object whose contents live in a named shared-memory block.

    Every read unpickles the whole block and every write re-pickles the
    whole mapping back into it.  Mutating operations hold the class-wide
    MPLOCK for the full read-modify-write cycle; read-only operations do
    not take the lock.
    """

    NULL_BYTE = b'\x00'
    MPLOCK = Lock()

    def __init__(
        self,
        name: str,
        size: int,
    ) -> None:
        """Attach to the block called *name*, creating it if necessary.

        Args:
            name: name of the shared-memory block.
            size: size in bytes used when the block has to be created.
        """
        super().__init__()
        self._serializer = PickleSerializer()
        self.shared_memory = self._get_or_create_memory_block(name, size)
        self._ensure_memory_initialization()

    def _get_or_create_memory_block(
        self, name: str, size: int
    ) -> shared_memory.SharedMemory:
        """Open the existing block *name*, or create it with *size* bytes."""
        try:
            return shared_memory.SharedMemory(name=name)
        except FileNotFoundError:
            return shared_memory.SharedMemory(name=name, create=True, size=size)

    def _ensure_memory_initialization(self) -> None:
        """Store an empty dict if the block has not been written yet.

        A block whose first byte is the null byte is treated as unwritten.
        """
        # Only the first byte decides; avoid copying the whole buffer.
        first_byte = bytes(self.shared_memory.buf[:1])
        memory_is_empty = first_byte in (b'', SharedDict.NULL_BYTE)
        if memory_is_empty:
            self.clear()

    def close(self) -> None:
        """Close this instance's handle on the shared-memory block."""
        # __init__ may have failed before the attribute was assigned, and
        # __del__ still calls close() in that case.
        if not hasattr(self, 'shared_memory'):
            return
        self.shared_memory.close()

    def cleanup(self) -> None:
        """Unlink the underlying shared-memory block."""
        if not hasattr(self, 'shared_memory'):
            return
        self.shared_memory.unlink()

    @synchronized(MPLOCK)
    def clear(self) -> None:
        """Replace the stored mapping with an empty dict."""
        self._save_memory({})

    def popitem(self, last: Optional[bool] = None) -> Any:
        """Remove and return a (key, value) pair as dict.popitem() does.

        The *last* argument is accepted but ignored.

        Raises:
            KeyError: if the stored mapping is empty.
        """
        with self._modify_db() as db:
            return db.popitem()

    @contextmanager
    def _modify_db(self) -> Generator:
        """Yield the stored dict and write it back when the block exits.

        MPLOCK is held from the read until after the write.  The lock has
        to be taken inside the generator: calling a @contextmanager
        function only builds the context-manager object, so a decorator
        that wraps the call would release the lock before the ``with``
        body runs.  If the ``with`` body raises, nothing is written back.
        """
        with SharedDict.MPLOCK:
            db = self._read_memory()
            yield db
            self._save_memory(db)

    def __getitem__(self, key: str) -> Any:
        """Return the value stored under *key*; raise KeyError if absent."""
        return self._read_memory()[key]

    def __setitem__(self, key: str, value: Any) -> None:
        """Store *value* under *key*."""
        with self._modify_db() as db:
            db[key] = value

    def __len__(self) -> int:
        """Return the number of stored items."""
        return len(self._read_memory())

    def __delitem__(self, key: str) -> None:
        """Delete *key*; raise KeyError if absent."""
        with self._modify_db() as db:
            del db[key]

    def __iter__(self) -> Iterator:
        """Iterate over the keys of a snapshot of the stored mapping."""
        return iter(self._read_memory())

    def __reversed__(self):
        """Iterate over the snapshot's keys in reverse order."""
        return reversed(self._read_memory())

    def __del__(self) -> None:
        """Close the shared-memory handle when the object is collected."""
        self.close()

    def __contains__(self, key: str) -> bool:
        """Return True if *key* is stored."""
        return key in self._read_memory()

    def __eq__(self, other: Any) -> bool:
        """Compare the stored dict with *other*."""
        return self._read_memory() == other

    def __ne__(self, other: Any) -> bool:
        """Return the inverse comparison of the stored dict with *other*."""
        return self._read_memory() != other

    def __str__(self) -> str:
        """Return str() of the stored dict."""
        return str(self._read_memory())

    def __repr__(self) -> str:
        """Return repr() of the stored dict."""
        return repr(self._read_memory())

    def get(self, key: str, default: Optional[Any] = None) -> Any:
        """Return the value for *key*, or *default* if it is absent."""
        return self._read_memory().get(key, default)

    def keys(self) -> KeysView[Any]:
        """Return the keys of a snapshot of the stored mapping."""
        return self._read_memory().keys()

    def values(self) -> ValuesView[Any]:
        """Return the values of a snapshot of the stored mapping."""
        return self._read_memory().values()

    def items(self) -> ItemsView:
        """Return the (key, value) pairs of a snapshot of the mapping."""
        return self._read_memory().items()

    def pop(self, key: str, default: Any = _MISSING) -> Any:
        """Remove *key* and return its value.

        Args:
            key: the key to remove.
            default: value returned when *key* is absent.  Any value,
                including None, counts as a supplied default.

        Raises:
            KeyError: if *key* is absent and no default was supplied.
        """
        with self._modify_db() as db:
            if default is _MISSING:
                return db.pop(key)
            return db.pop(key, default)

    def update(self, other=(), /, **kwds):
        """Merge *other* and *kwds* into the stored mapping (dict.update)."""
        with self._modify_db() as db:
            db.update(other, **kwds)

    def setdefault(self, key: str, default: Optional[Any] = None):
        """Return the value for *key*, storing *default* first if absent."""
        with self._modify_db() as db:
            return db.setdefault(key, default)

    def _save_memory(self, db: Dict[str, Any]) -> None:
        """Pickle *db* and copy it to the start of the memory block.

        Raises:
            ValueError: if the pickled data does not fit in the block.
        """
        data = self._serializer.dumps(db)
        try:
            # When data is longer than the buffer the slice is clamped, the
            # lengths differ, and memoryview raises ValueError.
            self.shared_memory.buf[: len(data)] = data
        except ValueError as exc:
            raise ValueError("exceeds available storage") from exc

    def _read_memory(self) -> Dict[str, Any]:
        """Unpickle and return the dict currently held in the block."""
        return self._serializer.loads(self.shared_memory.buf.tobytes())


if __name__ == '__main__':
    import doctest

    doctest.testmod()