From: Scott Date: Mon, 31 Jan 2022 01:25:59 +0000 (-0800) Subject: Add a MP shared dict. X-Git-Url: https://wannabe.guru.org/gitweb/?a=commitdiff_plain;h=53cc8afd5cedb13c201f00a0f3c64d070078a8d5;p=python_utils.git Add a MP shared dict. --- diff --git a/collect/shared_dict.py b/collect/shared_dict.py new file mode 100644 index 0000000..1c70c3d --- /dev/null +++ b/collect/shared_dict.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 + +import pickle +from contextlib import contextmanager +from functools import wraps +from multiprocessing import shared_memory, Lock +from typing import ( + Any, + Dict, + Generator, + KeysView, + ItemsView, + Iterator, + Optional, + ValuesView, +) + +from decorator_utils import synchronized + + +class PickleSerializer: + def dumps(self, obj: dict) -> bytes: + try: + return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL) + except pickle.PicklingError as e: + raise Exception(e) + + def loads(self, data: bytes) -> dict: + try: + return pickle.loads(data) + except pickle.UnpicklingError as e: + raise Exception(e) + + +# TODO: protobuf serializer? + + +class SharedDict(object): + NULL_BYTE = b'\x00' + MPLOCK = Lock() + + def __init__( + self, + name: str, + size: int, + ) -> None: + super().__init__() + self._serializer = PickleSerializer() + self.shared_memory = self._get_or_create_memory_block(name, size) + self._ensure_memory_initialization() + + def _get_or_create_memory_block( + self, name: str, size: int + ) -> shared_memory.SharedMemory: + try: + return shared_memory.SharedMemory(name=name) + except FileNotFoundError: + return shared_memory.SharedMemory(name=name, create=True, size=size) + + def _ensure_memory_initialization(self): + memory_is_empty = ( + bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b'' + ) + if memory_is_empty: + self.clear() + + def close(self) -> None: + if not hasattr(self, 'shared_memory'): + return + self.shared_memory.close() + + def cleanup(self) -> None: + if not hasattr(self, 'shared_memory'): + return + self.shared_memory.unlink() + + @synchronized(MPLOCK) + def clear(self) -> None: + self._save_memory({}) + + def popitem(self, last: Optional[bool] = None) -> Any: + with self._modify_db() as db: + return db.popitem() + + @synchronized(MPLOCK) + @contextmanager + def _modify_db(self) -> Generator: + db = self._read_memory() + yield db + self._save_memory(db) + + def __getitem__(self, key: str) -> Any: + return self._read_memory()[key] + + def __setitem__(self, key: str, value: Any) -> None: + with self._modify_db() as db: + db[key] = value + + def __len__(self) -> int: + return len(self._read_memory()) + + def __delitem__(self, key: str) -> None: + with self._modify_db() as db: + del db[key] + + def __iter__(self) -> Iterator: + return iter(self._read_memory()) + + def __reversed__(self): + return reversed(self._read_memory()) + + def __del__(self) -> None: + self.close() + + def __contains__(self, key: str) -> bool: + return key in self._read_memory() + + def __eq__(self, other: Any) -> bool: + return self._read_memory() == other + + def __ne__(self, other: Any) -> bool: + return self._read_memory() != other + + def __str__(self) -> str: + return str(self._read_memory()) + + def __repr__(self) -> str: + return repr(self._read_memory()) + + def get(self, key: str, default: Optional[Any] = None) -> Any: + return self._read_memory().get(key, default) + + def keys(self) -> KeysView[Any]: + return self._read_memory().keys() + + def values(self) -> ValuesView[Any]: + return self._read_memory().values() + + def items(self) -> ItemsView: + return self._read_memory().items() + + def pop(self, key: str, default: Optional[Any] = None): + with self._modify_db() as db: + if default is None: + return db.pop(key) + return db.pop(key, default) + + def update(self, other=(), /, **kwds): + with self._modify_db() as db: + db.update(other, **kwds) + + def setdefault(self, key: str, default: Optional[Any] = None): + with self._modify_db() as db: + return db.setdefault(key, default) + + def _save_memory(self, db: Dict[str, Any]) -> None: + data = self._serializer.dumps(db) + try: + self.shared_memory.buf[: len(data)] = data + except ValueError as exc: + raise ValueError("exceeds available storage") from exc + + def _read_memory(self) -> Dict[str, Any]: + return self._serializer.loads(self.shared_memory.buf.tobytes()) + + +if __name__ == '__main__': + import doctest + + doctest.testmod()