6 Copyright (c) 2020 LuizaLabs
7 Additions/Modifications Copyright (c) 2022 Scott Gasch
9 Permission is hereby granted, free of charge, to any person obtaining a copy
10 of this software and associated documentation files (the "Software"), to deal
11 in the Software without restriction, including without limitation the rights
12 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 copies of the Software, and to permit persons to whom the Software is
14 furnished to do so, subject to the following conditions:
16 The above copyright notice and this permission notice shall be included in all
17 copies or substantial portions of the Software.
19 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
27 This class is based on https://github.com/luizalabs/shared-memory-dict.
28 For details about what is preserved from the original and what was changed
29 by Scott, see NOTICE at the root of this module.
import pickle
from contextlib import contextmanager
from multiprocessing import RLock, shared_memory
from typing import (
    Any,
    Dict,
    Hashable,
    ItemsView,
    Iterator,
    KeysView,
    Optional,
    Tuple,
    ValuesView,
)
class PickleSerializer:
    """A serializer that uses pickling.  Used to read/write bytes in the shared
    memory region and interpret them as a dict."""

    def dumps(self, obj: Dict[Hashable, Any]) -> bytes:
        """Serialize obj to bytes using the highest available pickle protocol.

        Args:
            obj: the dict to serialize.

        Returns:
            The pickled representation of obj.

        Raises:
            Exception: obj could not be pickled; the underlying
                pickle.PicklingError is chained as __cause__.
        """
        try:
            return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
        except pickle.PicklingError as e:
            raise Exception(f"Unable to pickle object: {e}") from e

    def loads(self, data: bytes) -> Dict[Hashable, Any]:
        """Deserialize data back into the object it was pickled from.

        Bytes following the end of the pickle stream (e.g. unused
        padding in a fixed-size buffer) are ignored by pickle.

        Args:
            data: bytes previously produced by dumps.

        Returns:
            The unpickled object.

        Raises:
            Exception: data is not a valid pickle; the underlying
                pickle.UnpicklingError is chained as __cause__.
        """
        try:
            # NOTE: unpickling can execute arbitrary code; only use this on
            # data written by trusted processes.
            return pickle.loads(data)
        except pickle.UnpicklingError as e:
            raise Exception(f"Unable to unpickle data: {e}") from e
# TODOs: profile the serializers and figure out the fastest one.  Can
# we use a ChainMap to avoid the constant de/re-serialization of the
# whole dict on every access?
class SharedDict(object):
    """This class emulates the dict container but uses a
    multiprocessing.shared_memory.SharedMemory region to back the dict
    such that it can be read and written by multiple independent
    processes at the same time.  Because it constantly de/re-serializes
    the whole dict, it is much slower than a normal dict.

    Every read unpickles the entire region and every mutation unpickles,
    modifies and re-pickles it, so the views and iterators returned by
    this class are snapshots: they do not track later changes.
    """

    # An untouched shared memory region is zero-filled, whereas a pickle
    # stream never begins with a zero byte; this marks "nothing stored yet".
    NULL_BYTE = b'\x00'

    # Guards every read and write of the region.
    # NOTE(review): a multiprocessing RLock is only shared with processes
    # that inherit it (e.g. children of the creating process) -- confirm
    # that matches how callers attach.
    LOCK = RLock()

    def __init__(
        self,
        name: Optional[str] = None,
        size_bytes: Optional[int] = None,
    ) -> None:
        """
        Creates or attaches a shared dictionary backed by a SharedMemory buffer.
        For create semantics, a max dictionary size (expressed in bytes) must
        be provided and a unique name (string) may be.  For attach semantics,
        the name is required and the size is ignored.

        The first process that creates the SharedDict is responsible for
        (optionally) naming it and deciding the max size (in bytes) that
        it may be.  It does this via args to the c'tor.

        Subsequent processes attach by passing the same name and may
        safely omit the size.

        Args:
            name: name of the shared memory region to attach to or create.
            size_bytes: capacity of the region; only used when creating.

        Raises:
            AssertionError: size_bytes is not positive, or the region has
                to be created and no size_bytes was given.
        """
        assert size_bytes is None or size_bytes > 0
        self._serializer = PickleSerializer()
        self.shared_memory = self._get_or_create_memory_block(name, size_bytes)
        self._ensure_memory_initialization()
        self.name = self.shared_memory.name

    def get_name(self) -> str:
        """Returns the name of the shared memory buffer backing the dict."""
        return self.name

    def _get_or_create_memory_block(
        self,
        name: Optional[str] = None,
        size_bytes: Optional[int] = None,
    ) -> shared_memory.SharedMemory:
        """Attach to the region called name, creating it if it doesn't exist.

        Args:
            name: region name; None means "create a new region and let the
                platform pick a name".
            size_bytes: capacity to create with; required when creating.

        Returns:
            The attached or newly created SharedMemory object.

        Raises:
            AssertionError: a region must be created but size_bytes is None.
        """
        # SharedMemory(name=None) without create=True raises ValueError, not
        # FileNotFoundError, so only try to attach when we have a name.
        if name is not None:
            try:
                return shared_memory.SharedMemory(name=name)
            except FileNotFoundError:
                pass  # No such region yet; fall through and create it.
        assert size_bytes is not None
        return shared_memory.SharedMemory(name=name, create=True, size=size_bytes)

    def _ensure_memory_initialization(self) -> None:
        """Store an empty dict in the region if nothing was written yet."""
        with SharedDict.LOCK:
            # Only the first byte matters: peek at it rather than copying
            # the whole region out of shared memory.
            first_byte = bytes(self.shared_memory.buf[:1])
            memory_is_empty = first_byte in (b'', SharedDict.NULL_BYTE)
            if memory_is_empty:
                self.clear()

    def _write_memory(self, db: Dict[Hashable, Any]) -> None:
        """Serialize db and copy it to the start of the region.

        Raises:
            ValueError: the serialized dict is larger than the region.
        """
        data = self._serializer.dumps(db)
        with SharedDict.LOCK:
            try:
                # Stale bytes past len(data) are harmless; the pickle
                # stream is self-terminating.
                self.shared_memory.buf[: len(data)] = data
            except ValueError as e:
                raise ValueError("exceeds available storage") from e

    def _read_memory(self) -> Dict[Hashable, Any]:
        """Deserialize and return the dict currently stored in the region."""
        with SharedDict.LOCK:
            return self._serializer.loads(self.shared_memory.buf.tobytes())

    @contextmanager
    def _modify_dict(self) -> Iterator[Dict[Hashable, Any]]:
        """Context manager for read-modify-write cycles.

        Yields the deserialized dict and, if the with-body finishes
        without raising, writes it back.  The lock is held throughout.
        If the body raises, nothing is written back.
        """
        with SharedDict.LOCK:
            db = self._read_memory()
            yield db
            self._write_memory(db)

    def close(self) -> None:
        """Unmap the shared dict and memory behind it from this
        process.  Called automatically by __del__."""
        # __init__ may have failed before the attribute was assigned.
        if not hasattr(self, 'shared_memory'):
            return
        self.shared_memory.close()

    def cleanup(self) -> None:
        """Unlink the shared dict and memory behind it.  Only the last process should
        invoke this.  Not called automatically."""
        if not hasattr(self, 'shared_memory'):
            return
        with SharedDict.LOCK:
            self.shared_memory.unlink()

    def clear(self) -> None:
        """Clear the dict."""
        self._write_memory({})

    def copy(self) -> Dict[Hashable, Any]:
        """Returns a shallow copy of the dict."""
        return self._read_memory()

    def __getitem__(self, key: Hashable) -> Any:
        return self._read_memory()[key]

    def __setitem__(self, key: Hashable, value: Any) -> None:
        with self._modify_dict() as db:
            db[key] = value

    def __len__(self) -> int:
        return len(self._read_memory())

    def __delitem__(self, key: Hashable) -> None:
        with self._modify_dict() as db:
            del db[key]

    def __iter__(self) -> Iterator[Hashable]:
        return iter(self._read_memory())

    def __reversed__(self) -> Iterator[Hashable]:
        return reversed(self._read_memory())

    def __del__(self) -> None:
        self.close()

    def __contains__(self, key: Hashable) -> bool:
        return key in self._read_memory()

    def __eq__(self, other: Any) -> bool:
        return self._read_memory() == other

    def __ne__(self, other: Any) -> bool:
        return self._read_memory() != other

    def __str__(self) -> str:
        return str(self._read_memory())

    def __repr__(self) -> str:
        return repr(self._read_memory())

    def get(self, key: Hashable, default: Optional[Any] = None) -> Any:
        """Gets the value associated with key or a default."""
        return self._read_memory().get(key, default)

    def keys(self) -> KeysView[Hashable]:
        """Returns a view of the keys in a snapshot of the dict."""
        return self._read_memory().keys()

    def values(self) -> ValuesView[Any]:
        """Returns a view of the values in a snapshot of the dict."""
        return self._read_memory().values()

    def items(self) -> ItemsView[Hashable, Any]:
        """Returns a view of the (key, value) pairs in a snapshot of the dict."""
        return self._read_memory().items()

    def popitem(self) -> Tuple[Hashable, Any]:
        """Remove and return the last added item."""
        with self._modify_dict() as db:
            return db.popitem()

    def pop(self, key: Hashable, default: Optional[Any] = None) -> Any:
        """Remove and return the value associated with key or a default.

        A default of None means "no default": a missing key then raises
        KeyError, like dict.pop(key) does.
        """
        with self._modify_dict() as db:
            if default is None:
                return db.pop(key)
            return db.pop(key, default)

    def update(self, other=(), /, **kwds) -> None:
        """Update the dict from a mapping / iterable of pairs and/or kwargs."""
        with self._modify_dict() as db:
            db.update(other, **kwds)

    def setdefault(self, key: Hashable, default: Optional[Any] = None) -> Any:
        """Return the value for key, inserting default first if it is missing."""
        with self._modify_dict() as db:
            return db.setdefault(key, default)