#!/usr/bin/env python3

"""
The MIT License (MIT)

Copyright (c) 2020 LuizaLabs

Additions Copyright (c) 2022 Scott Gasch

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

This class is based on https://github.com/luizalabs/shared-memory-dict
"""

import pickle
from contextlib import contextmanager
from functools import wraps
from multiprocessing import shared_memory, RLock
from typing import (
    Any,
    Dict,
    Generator,
    KeysView,
    ItemsView,
    Iterator,
    Optional,
    ValuesView,
)

from decorator_utils import synchronized


class PickleSerializer:
    """Convert the shared mapping to bytes and back using :mod:`pickle`."""

    def dumps(self, obj: dict) -> bytes:
        """Pickle ``obj`` with the highest available pickle protocol.

        Args:
            obj: the mapping to serialize.

        Returns:
            The pickled bytes.

        Raises:
            Exception: wrapping a :class:`pickle.PicklingError`; the original
                error is attached as ``__cause__``. Other errors raised by
                :func:`pickle.dumps` (e.g. ``TypeError`` for an unpicklable
                value) propagate unwrapped.
        """
        try:
            return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
        except pickle.PicklingError as e:
            # Chain explicitly so the real pickling failure stays visible.
            raise Exception(e) from e

    def loads(self, data: bytes) -> dict:
        """Unpickle ``data``; any bytes after the end of the pickle are ignored.

        NOTE: unpickling can execute arbitrary code, so ``data`` must only come
        from trusted writers.

        Args:
            data: bytes (or a larger buffer) starting with a pickle.

        Returns:
            The deserialized object.

        Raises:
            Exception: wrapping a :class:`pickle.UnpicklingError`; the
                original error is attached as ``__cause__``. Other errors
                raised by :func:`pickle.loads` propagate unwrapped.
        """
        try:
            return pickle.loads(data)
        except pickle.UnpicklingError as e:
            raise Exception(e) from e


# TODO: protobuf serializer?
class SharedDict(object):
    """A dict-like mapping whose contents live in a named shared memory block.

    The whole mapping is pickled into the block on every write and unpickled
    from it on every read. Every operation therefore works on a fresh copy:
    values returned by reads, including the ``keys()``/``values()``/``items()``
    views, are snapshots that do not follow later changes.

    Based on https://github.com/luizalabs/shared-memory-dict.
    """

    # A block whose first byte is this value is treated as never written.
    NULL_BYTE = b'\x00'

    # Class-wide lock held around every read and read-modify-write of the
    # shared buffer.
    # NOTE(review): a multiprocessing lock only serializes processes that share
    # this very lock object (e.g. forked children). Unrelated processes that
    # attach to the same block by name each create their own lock and are not
    # protected. Confirm this matches how the class is used.
    MPLOCK = RLock()

    # Private sentinel that lets pop() tell "no default given" from default=None.
    _MISSING = object()

    def __init__(
        self,
        name: str,
        size_bytes: Optional[int] = None,
    ) -> None:
        """Attach to the shared memory block ``name``, creating it if needed.

        Args:
            name: name of the shared memory block.
            size_bytes: size to allocate when the block does not exist yet.
                It is required in that case.

        Raises:
            ValueError: if the block does not exist and ``size_bytes`` is
                missing or zero.
        """
        super().__init__()
        self.name = name
        self._serializer = PickleSerializer()
        self.shared_memory = self._get_or_create_memory_block(name, size_bytes)
        self._ensure_memory_initialization()
        # Per-instance lock. No method of this class uses it.
        self.lock = RLock()

    def get_name(self) -> str:
        """Return the name of the underlying shared memory block."""
        return self.name

    def _get_or_create_memory_block(
        self,
        name: str,
        size_bytes: Optional[int] = None,
    ) -> shared_memory.SharedMemory:
        """Attach to the existing block ``name``, or create it if it is absent.

        Raises:
            ValueError: if the block has to be created but ``size_bytes`` is
                missing or zero.
        """
        try:
            return shared_memory.SharedMemory(name=name)
        except FileNotFoundError:
            # Use an explicit check, not ``assert``, so that it still runs
            # under ``python -O``.
            if not size_bytes:
                raise ValueError(
                    f'shared memory block {name!r} does not exist and '
                    'size_bytes is required to create it'
                ) from None
            return shared_memory.SharedMemory(name=name, create=True, size=size_bytes)

    def _ensure_memory_initialization(self) -> None:
        """Store an empty dict if the block looks like it was never written.

        The block counts as unwritten when it is empty or its first byte is
        NUL. Only that first byte is copied, rather than the whole buffer.
        """
        first_byte = bytes(self.shared_memory.buf[:1])
        if first_byte in (b'', SharedDict.NULL_BYTE):
            self.clear()

    def close(self) -> None:
        """Close this process's handle to the block; the block itself persists."""
        if not hasattr(self, 'shared_memory'):
            # __init__ failed before the block was attached.
            return
        self.shared_memory.close()

    def cleanup(self) -> None:
        """Unlink (request destruction of) the underlying shared memory block."""
        if not hasattr(self, 'shared_memory'):
            return
        self.shared_memory.unlink()

    def clear(self) -> None:
        """Replace the stored mapping with an empty dict."""
        self._save_memory({})

    def popitem(self, last: Optional[bool] = None) -> Any:
        """Remove and return the most recently inserted ``(key, value)`` pair.

        ``last`` is accepted but currently ignored.

        Raises:
            KeyError: if the mapping is empty.
        """
        with self._modify_db() as db:
            return db.popitem()

    @contextmanager
    def _modify_db(self) -> Generator:
        """Yield the current mapping under MPLOCK and write it back afterwards.

        If the ``with`` body raises, the exception propagates and nothing is
        written back.
        """
        with SharedDict.MPLOCK:
            db = self._read_memory()
            yield db
            self._save_memory(db)

    def __getitem__(self, key: str) -> Any:
        return self._read_memory()[key]

    def __setitem__(self, key: str, value: Any) -> None:
        with self._modify_db() as db:
            db[key] = value

    def __len__(self) -> int:
        return len(self._read_memory())

    def __delitem__(self, key: str) -> None:
        with self._modify_db() as db:
            del db[key]

    def __iter__(self) -> Iterator:
        # Iterates a snapshot, so concurrent writers cannot disturb it.
        return iter(self._read_memory())

    def __reversed__(self):
        return reversed(self._read_memory())

    def __del__(self) -> None:
        self.close()

    def __contains__(self, key: str) -> bool:
        return key in self._read_memory()

    def __eq__(self, other: Any) -> bool:
        return self._read_memory() == other

    def __ne__(self, other: Any) -> bool:
        return self._read_memory() != other

    def __str__(self) -> str:
        return str(self._read_memory())

    def __repr__(self) -> str:
        return repr(self._read_memory())

    def get(self, key: str, default: Optional[Any] = None) -> Any:
        """Return the value for ``key``, or ``default`` if it is absent."""
        return self._read_memory().get(key, default)

    def keys(self) -> KeysView[Any]:
        """Return a keys view over a snapshot of the mapping."""
        return self._read_memory().keys()

    def values(self) -> ValuesView[Any]:
        """Return a values view over a snapshot of the mapping."""
        return self._read_memory().values()

    def items(self) -> ItemsView:
        """Return an items view over a snapshot of the mapping."""
        return self._read_memory().items()

    def pop(self, key: str, default: Any = _MISSING) -> Any:
        """Remove ``key`` and return its value, following ``dict.pop`` semantics.

        If ``key`` is absent, ``default`` is returned when it was supplied
        (``None`` included). Otherwise KeyError is raised.
        """
        with self._modify_db() as db:
            if default is SharedDict._MISSING:
                return db.pop(key)
            return db.pop(key, default)

    def update(self, other=(), /, **kwds):
        """Update the mapping like ``dict.update``, in a single locked write."""
        with self._modify_db() as db:
            db.update(other, **kwds)

    def setdefault(self, key: str, default: Optional[Any] = None):
        """Return ``self[key]``, first storing ``default`` if ``key`` is absent."""
        with self._modify_db() as db:
            return db.setdefault(key, default)

    def _save_memory(self, db: Dict[str, Any]) -> None:
        """Serialize ``db`` into the start of the shared buffer.

        Bytes past the new pickle are left as they are; the reader stops at
        the pickle's end.

        Raises:
            ValueError: if the serialized mapping does not fit in the block.
        """
        with SharedDict.MPLOCK:
            data = self._serializer.dumps(db)
            try:
                self.shared_memory.buf[: len(data)] = data
            except ValueError as exc:
                raise ValueError("exceeds available storage") from exc

    def _read_memory(self) -> Dict[str, Any]:
        """Deserialize and return a fresh copy of the stored mapping.

        NOTE: this unpickles shared memory. Any process able to write the block
        can run code in this one, so share the block only with trusted
        processes.
        """
        with SharedDict.MPLOCK:
            return self._serializer.loads(self.shared_memory.buf.tobytes())


if __name__ == '__main__':
    import doctest

    doctest.testmod()