6 Copyright (c) 2020 LuizaLabs
7 Additions Copyright (c) 2022 Scott Gasch
9 Permission is hereby granted, free of charge, to any person obtaining a copy
10 of this software and associated documentation files (the "Software"), to deal
11 in the Software without restriction, including without limitation the rights
12 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 copies of the Software, and to permit persons to whom the Software is
14 furnished to do so, subject to the following conditions:
16 The above copyright notice and this permission notice shall be included in all
17 copies or substantial portions of the Software.
19 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
27 This class is based on https://github.com/luizalabs/shared-memory-dict
31 from contextlib import contextmanager
32 from functools import wraps
33 from multiprocessing import RLock, shared_memory
45 from decorator_utils import synchronized
48 class PickleSerializer:
49 def dumps(self, obj: dict) -> bytes:
51 return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
52 except pickle.PicklingError as e:
55 def loads(self, data: bytes) -> dict:
57 return pickle.loads(data)
58 except pickle.UnpicklingError as e:
62 # TODO: protobuf serializer?
65 class SharedDict(object):
72 size_bytes: Optional[int] = None,
76 self._serializer = PickleSerializer()
77 assert size_bytes is None or size_bytes > 0
78 self.shared_memory = self._get_or_create_memory_block(name, size_bytes)
79 self._ensure_memory_initialization()
85 def _get_or_create_memory_block(
88 size_bytes: Optional[int] = None,
89 ) -> shared_memory.SharedMemory:
91 return shared_memory.SharedMemory(name=name)
92 except FileNotFoundError:
93 assert size_bytes is not None
94 return shared_memory.SharedMemory(name=name, create=True, size=size_bytes)
96 def _ensure_memory_initialization(self):
98 bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
103 def close(self) -> None:
104 if not hasattr(self, 'shared_memory'):
106 self.shared_memory.close()
108 def cleanup(self) -> None:
109 if not hasattr(self, 'shared_memory'):
111 with SharedDict.MPLOCK:
112 self.shared_memory.unlink()
def clear(self) -> None:
    """Overwrite the stored mapping with an empty dict."""
    empty = {}
    self._save_memory(empty)
117 def popitem(self, last: Optional[bool] = None) -> Any:
118 with self._modify_db() as db:
122 def _modify_db(self) -> Generator:
123 with SharedDict.MPLOCK:
124 db = self._read_memory()
126 self._save_memory(db)
128 def __getitem__(self, key: str) -> Any:
129 return self._read_memory()[key]
131 def __setitem__(self, key: str, value: Any) -> None:
132 with self._modify_db() as db:
135 def __len__(self) -> int:
136 return len(self._read_memory())
138 def __delitem__(self, key: str) -> None:
139 with self._modify_db() as db:
142 def __iter__(self) -> Iterator:
143 return iter(self._read_memory())
145 def __reversed__(self):
146 return reversed(self._read_memory())
148 def __del__(self) -> None:
151 def __contains__(self, key: str) -> bool:
152 return key in self._read_memory()
154 def __eq__(self, other: Any) -> bool:
155 return self._read_memory() == other
157 def __ne__(self, other: Any) -> bool:
158 return self._read_memory() != other
160 def __str__(self) -> str:
161 return str(self._read_memory())
163 def __repr__(self) -> str:
164 return repr(self._read_memory())
def get(self, key: str, default: Optional[Any] = None) -> Any:
    """Return the value stored under ``key``, or ``default`` if it is absent.

    Args:
        key: the key to look up in the current stored mapping.
        default: the value handed back when ``key`` is not present.
    """
    snapshot = self._read_memory()
    return snapshot.get(key, default)
def keys(self) -> KeysView[Any]:
    """Return a keys view over a snapshot of the stored mapping.

    The view belongs to the copy read at call time, not to later writes.
    """
    snapshot = self._read_memory()
    return snapshot.keys()
def values(self) -> ValuesView[Any]:
    """Return a values view over a snapshot of the stored mapping.

    The view belongs to the copy read at call time, not to later writes.
    """
    snapshot = self._read_memory()
    return snapshot.values()
def items(self) -> ItemsView:
    """Return an items view over a snapshot of the stored mapping.

    The view belongs to the copy read at call time, not to later writes.
    """
    snapshot = self._read_memory()
    return snapshot.items()
178 def pop(self, key: str, default: Optional[Any] = None):
179 with self._modify_db() as db:
182 return db.pop(key, default)
def update(self, other=(), /, **kwds):
    """Merge ``other`` and any keyword pairs into the stored mapping.

    Arguments are forwarded unchanged to ``update`` on the mapping handed
    out by ``_modify_db()``, so they follow that method's conventions.
    """
    with self._modify_db() as contents:
        contents.update(other, **kwds)
def setdefault(self, key: str, default: Optional[Any] = None):
    """Return the value for ``key``, inserting ``default`` first if absent.

    Args:
        key: the key to look up or create.
        default: the value stored (and returned) when ``key`` is missing.
    """
    with self._modify_db() as contents:
        result = contents.setdefault(key, default)
    # The context manager has already exited by the time we return.
    return result
192 def _save_memory(self, db: Dict[str, Any]) -> None:
193 with SharedDict.MPLOCK:
194 data = self._serializer.dumps(db)
196 self.shared_memory.buf[: len(data)] = data
197 except ValueError as exc:
198 raise ValueError("exceeds available storage") from exc
def _read_memory(self) -> Dict[str, Any]:
    """Deserialize and return the mapping currently held in shared memory.

    The whole buffer is copied to bytes and passed to the serializer while
    ``SharedDict.MPLOCK`` is held.
    """
    with SharedDict.MPLOCK:
        decode = self._serializer.loads
        raw = self.shared_memory.buf.tobytes()
        return decode(raw)
205 if __name__ == '__main__':