X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=collect%2Fshared_dict.py;h=3207927ed2f550b6516bce0c1b72fd96d7581ba4;hb=532df2c5b57c7517dfb3dddd8c1358fbadf8baf3;hp=93aa452d50f9bb383d60662a3925ee51e438b644;hpb=b93993dc08103ac47301556f8dd6fb8b71ee2551;p=python_utils.git diff --git a/collect/shared_dict.py b/collect/shared_dict.py index 93aa452..3207927 100644 --- a/collect/shared_dict.py +++ b/collect/shared_dict.py @@ -4,7 +4,7 @@ The MIT License (MIT) Copyright (c) 2020 LuizaLabs -Additions Copyright (c) 2022 Scott Gasch +Additions/Modifications Copyright (c) 2022 Scott Gasch Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal @@ -29,21 +29,18 @@ This class is based on https://github.com/luizalabs/shared-memory-dict import pickle from contextlib import contextmanager -from functools import wraps -from multiprocessing import shared_memory, Lock +from multiprocessing import RLock, shared_memory from typing import ( Any, Dict, Generator, - KeysView, ItemsView, Iterator, + KeysView, Optional, ValuesView, ) -from decorator_utils import synchronized - class PickleSerializer: def dumps(self, obj: dict) -> bytes: @@ -64,30 +61,37 @@ class PickleSerializer: class SharedDict(object): NULL_BYTE = b'\x00' - MPLOCK = Lock() + MPLOCK = RLock() def __init__( self, name: str, - size: int, + size_bytes: Optional[int] = None, ) -> None: super().__init__() + self.name = name self._serializer = PickleSerializer() - self.shared_memory = self._get_or_create_memory_block(name, size) + assert size_bytes is None or size_bytes > 0 + self.shared_memory = self._get_or_create_memory_block(name, size_bytes) self._ensure_memory_initialization() + self.lock = RLock() + + def get_name(self): + return self.name def _get_or_create_memory_block( - self, name: str, size: int + self, + name: str, + size_bytes: Optional[int] = None, ) -> shared_memory.SharedMemory: try: return shared_memory.SharedMemory(name=name) except FileNotFoundError: - return shared_memory.SharedMemory(name=name, create=True, size=size) + assert size_bytes is not None + return shared_memory.SharedMemory(name=name, create=True, size=size_bytes) def _ensure_memory_initialization(self): - memory_is_empty = ( - bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b'' - ) + memory_is_empty = bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b'' if memory_is_empty: self.clear() @@ -99,9 +103,9 @@ class SharedDict(object): def cleanup(self) -> None: if not hasattr(self, 'shared_memory'): return - self.shared_memory.unlink() + with SharedDict.MPLOCK: + self.shared_memory.unlink() - @synchronized(MPLOCK) def clear(self) -> None: self._save_memory({}) @@ -109,12 +113,12 @@ class SharedDict(object): with self._modify_db() as db: return db.popitem() - @synchronized(MPLOCK) @contextmanager def _modify_db(self) -> Generator: - db = self._read_memory() - yield db - self._save_memory(db) + with SharedDict.MPLOCK: + db = self._read_memory() + yield db + self._save_memory(db) def __getitem__(self, key: str) -> Any: return self._read_memory()[key] @@ -181,14 +185,16 @@ class SharedDict(object): return db.setdefault(key, default) def _save_memory(self, db: Dict[str, Any]) -> None: - data = self._serializer.dumps(db) - try: - self.shared_memory.buf[: len(data)] = data - except ValueError as exc: - raise ValueError("exceeds available storage") from exc + with SharedDict.MPLOCK: + data = self._serializer.dumps(db) + try: + self.shared_memory.buf[: len(data)] = data + except ValueError as exc: + raise ValueError("exceeds available storage") from exc def _read_memory(self) -> Dict[str, Any]: - return self._serializer.loads(self.shared_memory.buf.tobytes()) + with SharedDict.MPLOCK: + return self._serializer.loads(self.shared_memory.buf.tobytes()) if __name__ == '__main__':