3 """The MIT License (MIT)
5 Copyright (c) 2020 LuizaLabs
7 Additions/Modifications Copyright (c) 2022 Scott Gasch
9 Permission is hereby granted, free of charge, to any person obtaining a copy
10 of this software and associated documentation files (the "Software"), to deal
11 in the Software without restriction, including without limitation the rights
12 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 copies of the Software, and to permit persons to whom the Software is
14 furnished to do so, subject to the following conditions:
16 The above copyright notice and this permission notice shall be included in all
17 copies or substantial portions of the Software.
19 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
27 This class is based on
28 https://github.com/luizalabs/shared-memory-dict. For details about
29 what is preserved from the original and what was changed by Scott, see
31 <https://wannabe.guru.org/gitweb/?p=pyutils.git;a=blob_plain;f=NOTICE;hb=HEAD>`_
32 at the root of this module.
37 from contextlib import contextmanager
38 from multiprocessing import RLock, shared_memory
class PickleSerializer:
    """A serializer that uses pickling.  Used to read/write bytes in the shared
    memory region and interpret them as a dict.

    SECURITY NOTE: :meth:`loads` unpickles whatever bytes it is handed.
    Unpickling can execute arbitrary code, so the bytes must only ever
    come from a trusted writer.
    """

    def dumps(self, obj: Dict[Hashable, Any]) -> bytes:
        """Serialize a dict to bytes using the highest pickle protocol.

        Args:
            obj: the dict to serialize

        Returns:
            The pickled representation of obj.

        Raises:
            Exception: chained from the underlying pickle.PicklingError
                if obj cannot be pickled.
        """
        try:
            return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
        except pickle.PicklingError as e:
            raise Exception("unable to pickle dict contents") from e

    def loads(self, data: bytes) -> Dict[Hashable, Any]:
        """Deserialize bytes previously produced by :meth:`dumps`.

        Any bytes following the end of the pickled object are ignored,
        so data may be a fixed-size buffer that is larger than the pickle
        it contains.

        Args:
            data: the bytes to deserialize; must come from a trusted source

        Returns:
            The dict that was pickled into data.

        Raises:
            Exception: chained from the underlying pickle.UnpicklingError
                if data is not a valid pickle.
        """
        try:
            return pickle.loads(data)
        except pickle.UnpicklingError as e:
            raise Exception("unable to unpickle dict contents") from e
# TODOs: profile the serializers and figure out the fastest one.  Can
# we use a ChainMap to avoid the constant de/re-serialization of the
# whole dict?
class SharedDict(object):
    """This class emulates the dict container but uses a
    Multiprocessing.SharedMemory region to back the dict such that it
    can be read and written by multiple independent processes at the
    same time.  Because it constantly de/re-serializes the dict, it is
    much slower than a normal dict.

    Every read deserializes the whole buffer into a private snapshot
    and every mutation reserializes the whole dict back into the
    buffer.  Consequently the objects returned by :meth:`copy`,
    :meth:`keys`, :meth:`values`, :meth:`items` and iteration reflect
    the contents at the time of the call and are not live views.
    """

    # A buffer whose first byte is NUL is treated as never written to.
    NULL_BYTE = b'\x00'

    # Re-entrant because _modify_dict holds the lock while it calls
    # _read_memory / _write_memory, which acquire it again.
    # NOTE(review): a class-level multiprocessing lock is only shared
    # with processes that inherit it (e.g. forked children) -- confirm
    # that is sufficient for the intended callers.
    LOCK = RLock()

    def __init__(
        self,
        name: Optional[str] = None,
        size_bytes: Optional[int] = None,
    ) -> None:
        """
        Creates or attaches a shared dictionary backed by a SharedMemory
        buffer.

        The first process that creates the SharedDict is responsible for
        (optionally) naming it and deciding the max size (in bytes) that
        it may be.  It does this via args to the c'tor.

        Processes attaching to an existing SharedDict pass its name and
        may omit size_bytes.

        Args:
            name: the name of the shared dict.  If a shared memory block
                with this name already exists it is attached to;
                otherwise one is created.  If None, a new block with a
                generated name is created (see :meth:`get_name`).
            size_bytes: the maximum size of data storable in the shared
                dict; only required when the block has to be created.

        Raises:
            ValueError: if neither name nor size_bytes is provided.
        """
        assert size_bytes is None or size_bytes > 0
        self._serializer = PickleSerializer()
        self.shared_memory = self._get_or_create_memory_block(name, size_bytes)
        self._ensure_memory_initialization()
        self.name = self.shared_memory.name

    def get_name(self) -> str:
        """
        Returns:
            The name of the shared memory buffer backing the dict.
        """
        return self.name

    def _get_or_create_memory_block(
        self,
        name: Optional[str] = None,
        size_bytes: Optional[int] = None,
    ) -> shared_memory.SharedMemory:
        """Internal helper.  Attach to the block called name if it
        exists; otherwise create a block of size_bytes bytes."""
        if name is not None:
            try:
                return shared_memory.SharedMemory(name=name)
            except FileNotFoundError:
                # Nothing to attach to; fall through and create it.
                pass
        elif size_bytes is None:
            # SharedMemory cannot attach without a name and cannot
            # create without a size, so there is nothing we can do.
            raise ValueError(
                "either a name (to attach) or size_bytes (to create) is required"
            )
        # SharedMemory rejects name=None unless create=True, so an
        # unnamed request must come straight here rather than trying
        # to attach first.
        assert size_bytes is not None
        return shared_memory.SharedMemory(name=name, create=True, size=size_bytes)

    def _ensure_memory_initialization(self) -> None:
        """Internal helper.  Store an empty dict in the buffer if it has
        never been written to."""
        with SharedDict.LOCK:
            # Only the first byte matters; avoid copying the whole buffer.
            first_byte = bytes(self.shared_memory.buf[:1])
            if first_byte in (b'', SharedDict.NULL_BYTE):
                self.clear()

    def _write_memory(self, db: Dict[Hashable, Any]) -> None:
        """Internal helper.  Serialize db and store it at the start of
        the buffer.

        Raises:
            ValueError: if the serialized dict does not fit in the buffer.
        """
        data = self._serializer.dumps(db)
        with SharedDict.LOCK:
            try:
                self.shared_memory.buf[: len(data)] = data
            except ValueError as e:
                raise ValueError("exceeds available storage") from e

    def _read_memory(self) -> Dict[Hashable, Any]:
        """Internal helper.  Deserialize and return the dict currently
        stored in the buffer."""
        with SharedDict.LOCK:
            return self._serializer.loads(self.shared_memory.buf.tobytes())

    @contextmanager
    def _modify_dict(self) -> Iterator[Dict[Hashable, Any]]:
        """Internal helper.  Yield the current dict for modification and
        write it back on normal exit.  If the with-body raises, nothing
        is written back."""
        with SharedDict.LOCK:
            db = self._read_memory()
            yield db
            self._write_memory(db)

    def close(self) -> None:
        """Unmap the shared dict and memory behind it from this
        process.  Called automatically by :meth:`__del__`.
        """
        # __init__ may have failed before the attribute was set.
        if not hasattr(self, 'shared_memory'):
            return
        self.shared_memory.close()

    def cleanup(self) -> None:
        """Unlink the shared dict and memory behind it.  Only the last process
        should invoke this.  Not called automatically."""
        if not hasattr(self, 'shared_memory'):
            return
        with SharedDict.LOCK:
            self.shared_memory.unlink()

    def clear(self) -> None:
        """Clears the shared dict."""
        self._write_memory({})

    def copy(self) -> Dict[Hashable, Any]:
        """
        Returns:
            A shallow copy of the shared dict.
        """
        return self._read_memory()

    def __getitem__(self, key: Hashable) -> Any:
        return self._read_memory()[key]

    def __setitem__(self, key: Hashable, value: Any) -> None:
        with self._modify_dict() as db:
            db[key] = value

    def __len__(self) -> int:
        return len(self._read_memory())

    def __delitem__(self, key: Hashable) -> None:
        with self._modify_dict() as db:
            del db[key]

    def __iter__(self) -> Iterator[Hashable]:
        return iter(self._read_memory())

    def __reversed__(self) -> Iterator[Hashable]:
        return reversed(self._read_memory())

    def __del__(self) -> None:
        self.close()

    def __contains__(self, key: Hashable) -> bool:
        return key in self._read_memory()

    def __eq__(self, other: Any) -> bool:
        return self._read_memory() == other

    def __ne__(self, other: Any) -> bool:
        return self._read_memory() != other

    def __str__(self) -> str:
        return str(self._read_memory())

    def __repr__(self) -> str:
        return repr(self._read_memory())

    def get(self, key: Hashable, default: Optional[Any] = None) -> Any:
        """
        Args:
            key: the key to lookup
            default: the value returned if key is not present

        Returns:
            The value associated with key or a default.
        """
        return self._read_memory().get(key, default)

    def keys(self) -> KeysView[Hashable]:
        """Return the keys of a snapshot of the shared dict."""
        return self._read_memory().keys()

    def values(self) -> ValuesView[Any]:
        """Return the values of a snapshot of the shared dict."""
        return self._read_memory().values()

    def items(self) -> ItemsView[Hashable, Any]:
        """Return the (key, value) pairs of a snapshot of the shared dict."""
        return self._read_memory().items()

    def popitem(self) -> Tuple[Hashable, Any]:
        """Remove and return the last added item."""
        with self._modify_dict() as db:
            return db.popitem()

    def pop(self, key: Hashable, default: Optional[Any] = None) -> Any:
        """Remove and return the value associated with key or a default.

        A default of None is treated as "no default": if key is missing
        and default is None, KeyError is raised.
        """
        with self._modify_dict() as db:
            if default is None:
                return db.pop(key)
            return db.pop(key, default)

    def update(self, other=(), /, **kwds) -> None:
        """Update the shared dict from other and/or keyword arguments,
        with the same semantics as dict.update."""
        with self._modify_dict() as db:
            db.update(other, **kwds)

    def setdefault(self, key: Hashable, default: Optional[Any] = None) -> Any:
        """Return the value for key, first storing default under key if
        key is not present."""
        with self._modify_dict() as db:
            return db.setdefault(key, default)