X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=collect%2Fshared_dict.py;fp=collect%2Fshared_dict.py;h=0d8e7c2f7a36aa5ddb7c54c72aecddbf56df71c3;hb=865825894beeedd47d26dd092d40bfee582f5475;hp=93aa452d50f9bb383d60662a3925ee51e438b644;hpb=5e09a33068fcdf6d43f12477dd943e108e11ae06;p=python_utils.git diff --git a/collect/shared_dict.py b/collect/shared_dict.py index 93aa452..0d8e7c2 100644 --- a/collect/shared_dict.py +++ b/collect/shared_dict.py @@ -30,7 +30,7 @@ This class is based on https://github.com/luizalabs/shared-memory-dict import pickle from contextlib import contextmanager from functools import wraps -from multiprocessing import shared_memory, Lock +from multiprocessing import shared_memory, RLock from typing import ( Any, Dict, @@ -64,25 +64,33 @@ class PickleSerializer: class SharedDict(object): NULL_BYTE = b'\x00' - MPLOCK = Lock() + MPLOCK = RLock() def __init__( self, name: str, - size: int, + size_bytes: Optional[int] = None, ) -> None: super().__init__() + self.name = name self._serializer = PickleSerializer() - self.shared_memory = self._get_or_create_memory_block(name, size) + self.shared_memory = self._get_or_create_memory_block(name, size_bytes) self._ensure_memory_initialization() + self.lock = RLock() + + def get_name(self): + return self.name def _get_or_create_memory_block( - self, name: str, size: int + self, + name: str, + size_bytes: Optional[int] = None, ) -> shared_memory.SharedMemory: try: return shared_memory.SharedMemory(name=name) except FileNotFoundError: - return shared_memory.SharedMemory(name=name, create=True, size=size) + assert size_bytes + return shared_memory.SharedMemory(name=name, create=True, size=size_bytes) def _ensure_memory_initialization(self): memory_is_empty = ( @@ -101,7 +109,6 @@ class SharedDict(object): return self.shared_memory.unlink() - @synchronized(MPLOCK) def clear(self) -> None: self._save_memory({}) @@ -109,12 +116,12 @@ class SharedDict(object): with self._modify_db() as db: return db.popitem() - @synchronized(MPLOCK) @contextmanager def _modify_db(self) -> Generator: - db = self._read_memory() - yield db - self._save_memory(db) + with SharedDict.MPLOCK: + db = self._read_memory() + yield db + self._save_memory(db) def __getitem__(self, key: str) -> Any: return self._read_memory()[key] @@ -181,14 +188,16 @@ class SharedDict(object): return db.setdefault(key, default) def _save_memory(self, db: Dict[str, Any]) -> None: - data = self._serializer.dumps(db) - try: - self.shared_memory.buf[: len(data)] = data - except ValueError as exc: - raise ValueError("exceeds available storage") from exc + with SharedDict.MPLOCK: + data = self._serializer.dumps(db) + try: + self.shared_memory.buf[: len(data)] = data + except ValueError as exc: + raise ValueError("exceeds available storage") from exc def _read_memory(self) -> Dict[str, Any]: - return self._serializer.loads(self.shared_memory.buf.tobytes()) + with SharedDict.MPLOCK: + return self._serializer.loads(self.shared_memory.buf.tobytes()) if __name__ == '__main__':