Since this thing is on the innerwebs I suppose it should have a license.
[python_utils.git] / collect / shared_dict.py
index 1c70c3dc74bcbbb3d5be99d4a799584f6fe3a72e..3207927ed2f550b6516bce0c1b72fd96d7581ba4 100644 (file)
@@ -1,22 +1,46 @@
 #!/usr/bin/env python3
 
+"""
+The MIT License (MIT)
+
+Copyright (c) 2020 LuizaLabs
+Additions/Modifications Copyright (c) 2022 Scott Gasch
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+This class is based on https://github.com/luizalabs/shared-memory-dict
+"""
+
 import pickle
 from contextlib import contextmanager
-from functools import wraps
-from multiprocessing import shared_memory, Lock
+from multiprocessing import RLock, shared_memory
 from typing import (
     Any,
     Dict,
     Generator,
-    KeysView,
     ItemsView,
     Iterator,
+    KeysView,
     Optional,
     ValuesView,
 )
 
-from decorator_utils import synchronized
-
 
 class PickleSerializer:
     def dumps(self, obj: dict) -> bytes:
@@ -37,30 +61,37 @@ class PickleSerializer:
 
 class SharedDict(object):
     NULL_BYTE = b'\x00'
-    MPLOCK = Lock()
+    MPLOCK = RLock()
 
     def __init__(
         self,
         name: str,
-        size: int,
+        size_bytes: Optional[int] = None,
     ) -> None:
         super().__init__()
+        self.name = name
         self._serializer = PickleSerializer()
-        self.shared_memory = self._get_or_create_memory_block(name, size)
+        assert size_bytes is None or size_bytes > 0
+        self.shared_memory = self._get_or_create_memory_block(name, size_bytes)
         self._ensure_memory_initialization()
+        self.lock = RLock()
+
+    def get_name(self):
+        return self.name
 
     def _get_or_create_memory_block(
-        self, name: str, size: int
+        self,
+        name: str,
+        size_bytes: Optional[int] = None,
     ) -> shared_memory.SharedMemory:
         try:
             return shared_memory.SharedMemory(name=name)
         except FileNotFoundError:
-            return shared_memory.SharedMemory(name=name, create=True, size=size)
+            assert size_bytes is not None
+            return shared_memory.SharedMemory(name=name, create=True, size=size_bytes)
 
     def _ensure_memory_initialization(self):
-        memory_is_empty = (
-            bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
-        )
+        memory_is_empty = bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
         if memory_is_empty:
             self.clear()
 
@@ -72,9 +103,9 @@ class SharedDict(object):
     def cleanup(self) -> None:
         if not hasattr(self, 'shared_memory'):
             return
-        self.shared_memory.unlink()
+        with SharedDict.MPLOCK:
+            self.shared_memory.unlink()
 
-    @synchronized(MPLOCK)
     def clear(self) -> None:
         self._save_memory({})
 
@@ -82,12 +113,12 @@ class SharedDict(object):
         with self._modify_db() as db:
             return db.popitem()
 
-    @synchronized(MPLOCK)
     @contextmanager
     def _modify_db(self) -> Generator:
-        db = self._read_memory()
-        yield db
-        self._save_memory(db)
+        with SharedDict.MPLOCK:
+            db = self._read_memory()
+            yield db
+            self._save_memory(db)
 
     def __getitem__(self, key: str) -> Any:
         return self._read_memory()[key]
@@ -154,14 +185,16 @@ class SharedDict(object):
             return db.setdefault(key, default)
 
     def _save_memory(self, db: Dict[str, Any]) -> None:
-        data = self._serializer.dumps(db)
-        try:
-            self.shared_memory.buf[: len(data)] = data
-        except ValueError as exc:
-            raise ValueError("exceeds available storage") from exc
+        with SharedDict.MPLOCK:
+            data = self._serializer.dumps(db)
+            try:
+                self.shared_memory.buf[: len(data)] = data
+            except ValueError as exc:
+                raise ValueError("exceeds available storage") from exc
 
     def _read_memory(self) -> Dict[str, Any]:
-        return self._serializer.loads(self.shared_memory.buf.tobytes())
+        with SharedDict.MPLOCK:
+            return self._serializer.loads(self.shared_memory.buf.tobytes())
 
 
 if __name__ == '__main__':