6 Copyright (c) 2020 LuizaLabs
8 Additions/Modifications Copyright (c) 2022 Scott Gasch
10 Permission is hereby granted, free of charge, to any person obtaining a copy
11 of this software and associated documentation files (the "Software"), to deal
12 in the Software without restriction, including without limitation the rights
13 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14 copies of the Software, and to permit persons to whom the Software is
15 furnished to do so, subject to the following conditions:
17 The above copyright notice and this permission notice shall be included in all
18 copies or substantial portions of the Software.
20 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
28 This class is based on https://github.com/luizalabs/shared-memory-dict.
29 For details about what is preserved from the original and what was changed
30 by Scott, see NOTICE at the root of this module.
import pickle
from contextlib import contextmanager
from multiprocessing import RLock, shared_memory
from typing import (
    Any,
    Dict,
    Hashable,
    ItemsView,
    Iterator,
    KeysView,
    Optional,
    Tuple,
    ValuesView,
)
class PickleSerializer:
    """A serializer that uses pickling.  Used to read/write bytes in the shared
    memory region and interpret them as a dict."""

    def dumps(self, obj: Dict[Hashable, Any]) -> bytes:
        """Serialize obj into bytes using the highest pickle protocol.

        Raises:
            Exception: obj could not be pickled.
        """
        try:
            return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
        except pickle.PicklingError as e:
            raise Exception("could not pickle object") from e

    def loads(self, data: bytes) -> Dict[Hashable, Any]:
        """Deserialize data (bytes) back into a dict.

        Raises:
            Exception: data could not be unpickled.
        """
        try:
            return pickle.loads(data)
        except pickle.UnpicklingError as e:
            raise Exception("could not unpickle data") from e
# TODOs: profile the serializers and figure out the fastest one.  Can
# we use a ChainMap to avoid the constant de/re-serialization of the
# whole dict on every read and write?
class SharedDict(object):
    """This class emulates the dict container but uses a
    Multiprocessing.SharedMemory region to back the dict such that it
    can be read and written by multiple independent processes at the
    same time.  Because it constantly de/re-serializes the dict, it is
    much slower than a normal dict.
    """

    # A freshly created SharedMemory buffer is zero filled; a leading
    # null byte therefore means "never written / uninitialized".
    NULL_BYTE = b'\x00'

    # Process-wide lock used to serialize every read/write of the
    # shared memory region made through this class.
    LOCK = RLock()

    def __init__(
        self,
        name: Optional[str] = None,
        size_bytes: Optional[int] = None,
    ) -> None:
        """
        Creates or attaches a shared dictionary backed by a SharedMemory
        buffer.  For create semantics, a unique name (string) and a max
        dictionary size (expressed in bytes) must be provided.  For attach
        semantics, these are optional.

        The first process that creates the SharedDict is responsible for
        (optionally) naming it and deciding the max size (in bytes) that
        it may be.  It does this via args to the c'tor.

        Subsequent processes may safely omit name and size args.

        Args:
            name: the name of the shared memory region; omit to let the
                kernel pick one.
            size_bytes: the maximum size, in bytes, of the serialized
                dict.  Required when creating; ignored when attaching.
        """
        assert size_bytes is None or size_bytes > 0
        self._serializer = PickleSerializer()
        self.shared_memory = self._get_or_create_memory_block(name, size_bytes)
        self._ensure_memory_initialization()
        self.name = self.shared_memory.name

    def get_name(self) -> str:
        """Returns the name of the shared memory buffer backing the dict."""
        return self.name

    def _get_or_create_memory_block(
        self,
        name: Optional[str] = None,
        size_bytes: Optional[int] = None,
    ) -> shared_memory.SharedMemory:
        """Attach to an existing shared memory region by name; if none
        exists, create a new one of size_bytes."""
        try:
            return shared_memory.SharedMemory(name=name)
        except FileNotFoundError:
            # Creation requires an explicit maximum size.
            assert size_bytes is not None
            return shared_memory.SharedMemory(name=name, create=True, size=size_bytes)

    def _ensure_memory_initialization(self):
        """If the backing region has never been written (still all null
        bytes), initialize it to hold an empty serialized dict."""
        with SharedDict.LOCK:
            memory_is_empty = (
                bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
            )
            if memory_is_empty:
                self.clear()

    def _write_memory(self, db: Dict[Hashable, Any]) -> None:
        """Serialize db and write it into the shared memory region.

        Raises:
            ValueError: the serialized dict is larger than the region.
        """
        data = self._serializer.dumps(db)
        with SharedDict.LOCK:
            try:
                self.shared_memory.buf[: len(data)] = data
            except ValueError as e:
                raise ValueError("exceeds available storage") from e

    def _read_memory(self) -> Dict[Hashable, Any]:
        """Read the shared memory region and deserialize it into a dict."""
        with SharedDict.LOCK:
            return self._serializer.loads(self.shared_memory.buf.tobytes())

    @contextmanager
    def _modify_dict(self):
        """Context manager that yields the current dict contents for
        mutation and, holding the lock for the whole read-modify-write
        cycle, writes them back to shared memory on exit."""
        with SharedDict.LOCK:
            db = self._read_memory()
            yield db
            self._write_memory(db)

    def close(self) -> None:
        """Unmap the shared dict and memory behind it from this
        process.  Called automatically by __del__."""
        if not hasattr(self, 'shared_memory'):
            return
        self.shared_memory.close()

    def cleanup(self) -> None:
        """Unlink the shared dict and memory behind it.  Only the last process should
        invoke this.  Not called automatically."""
        if not hasattr(self, 'shared_memory'):
            return
        with SharedDict.LOCK:
            self.shared_memory.unlink()

    def clear(self) -> None:
        """Clear the dict."""
        self._write_memory({})

    def copy(self) -> Dict[Hashable, Any]:
        """Returns a shallow copy of the dict."""
        return self._read_memory()

    def __getitem__(self, key: Hashable) -> Any:
        return self._read_memory()[key]

    def __setitem__(self, key: Hashable, value: Any) -> None:
        with self._modify_dict() as db:
            db[key] = value

    def __len__(self) -> int:
        return len(self._read_memory())

    def __delitem__(self, key: Hashable) -> None:
        with self._modify_dict() as db:
            del db[key]

    def __iter__(self) -> Iterator[Hashable]:
        return iter(self._read_memory())

    def __reversed__(self) -> Iterator[Hashable]:
        return reversed(self._read_memory())

    def __del__(self) -> None:
        self.close()

    def __contains__(self, key: Hashable) -> bool:
        return key in self._read_memory()

    def __eq__(self, other: Any) -> bool:
        return self._read_memory() == other

    def __ne__(self, other: Any) -> bool:
        return self._read_memory() != other

    def __str__(self) -> str:
        return str(self._read_memory())

    def __repr__(self) -> str:
        return repr(self._read_memory())

    def get(self, key: str, default: Optional[Any] = None) -> Any:
        """Gets the value associated with key or a default."""
        return self._read_memory().get(key, default)

    def keys(self) -> KeysView[Hashable]:
        return self._read_memory().keys()

    def values(self) -> ValuesView[Any]:
        return self._read_memory().values()

    def items(self) -> ItemsView[Hashable, Any]:
        return self._read_memory().items()

    def popitem(self) -> Tuple[Hashable, Any]:
        """Remove and return the last added item."""
        with self._modify_dict() as db:
            return db.popitem()

    def pop(self, key: Hashable, default: Optional[Any] = None) -> Any:
        """Remove and return the value associated with key or a default"""
        with self._modify_dict() as db:
            if default is None:
                # No default supplied: propagate KeyError like dict.pop.
                return db.pop(key)
            return db.pop(key, default)

    def update(self, other=(), /, **kwds):
        """Merge other (and any keyword args) into the dict."""
        with self._modify_dict() as db:
            db.update(other, **kwds)

    def setdefault(self, key: Hashable, default: Optional[Any] = None):
        """Return key's value, first inserting default if key is absent."""
        with self._modify_dict() as db:
            return db.setdefault(key, default)