6 Copyright (c) 2020 LuizaLabs
7 Additions/Modifications Copyright (c) 2022 Scott Gasch
9 Permission is hereby granted, free of charge, to any person obtaining a copy
10 of this software and associated documentation files (the "Software"), to deal
11 in the Software without restriction, including without limitation the rights
12 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 copies of the Software, and to permit persons to whom the Software is
14 furnished to do so, subject to the following conditions:
16 The above copyright notice and this permission notice shall be included in all
17 copies or substantial portions of the Software.
19 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
27 This class is based on https://github.com/luizalabs/shared-memory-dict
31 from contextlib import contextmanager
32 from multiprocessing import RLock, shared_memory
class PickleSerializer:
    """A serializer that uses pickling.  Used to read/write bytes in the
    shared memory region and interpret them as a dict."""

    def dumps(self, obj: Dict[Hashable, Any]) -> bytes:
        """Serialize obj using the highest available pickle protocol.

        Raises:
            Exception: wrapping a pickle.PicklingError (the original error
                is chained as __cause__).  Other exception types raised by
                pickle propagate unchanged.
        """
        try:
            return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
        except pickle.PicklingError as e:
            # Keep the generic type callers may already catch, but say what
            # actually failed instead of raising a message-less Exception.
            raise Exception(f"failed to pickle shared dict: {e}") from e

    def loads(self, data: bytes) -> Dict[Hashable, Any]:
        """Deserialize and return the object pickled at the start of data.

        pickle.loads ignores any bytes after the end of the pickle, which is
        why data may be an entire (larger) shared memory buffer.

        NOTE: unpickling can execute arbitrary code; only read buffers that
        were written by trusted processes.

        Raises:
            Exception: wrapping a pickle.UnpicklingError (the original error
                is chained as __cause__).  Other exception types raised by
                pickle propagate unchanged.
        """
        try:
            return pickle.loads(data)
        except pickle.UnpicklingError as e:
            raise Exception(f"failed to unpickle shared dict: {e}") from e
# TODO: profile the serializers and figure out the fastest one.  Can
# we use a ChainMap to avoid the constant de/re-serialization of the
# whole dict on every operation?
class SharedDict(object):
    """This class emulates the dict container but uses a
    multiprocessing.shared_memory.SharedMemory region to back the dict such
    that it can be read and written by multiple processes at the same time.
    Because it constantly de/re-serializes the dict, it is much slower than
    a normal dict.

    Every operation deserializes the whole dict from the shared buffer and,
    for mutations, serializes it back.  Values returned by keys(), values(),
    items(), copy(), etc. are therefore snapshots that do not track later
    changes.
    """

    # A buffer whose first byte is NULL is treated as "no dict written yet".
    NULL_BYTE = b'\x00'

    # Reentrant because _modify_dict holds it while calling _read_memory and
    # _write_memory, which acquire it again.
    # NOTE(review): this lock is created when the class body executes, so it
    # only serializes processes that inherit it (e.g. children forked after
    # import); unrelated processes attaching by name each get their own lock.
    LOCK = RLock()

    # Private sentinel so pop() can tell "no default given" from default=None.
    _NO_DEFAULT: Any = object()

    def __init__(
        self,
        name: Optional[str] = None,
        size_bytes: Optional[int] = None,
    ) -> None:
        """
        Creates or attaches a shared dictionary backed by a SharedMemory buffer.

        If a block called name already exists it is attached to and
        size_bytes is ignored.  Otherwise a new block of size_bytes bytes
        (the maximum serialized size of the dict) is created; when name is
        None the system chooses a unique name.

        The first process that creates the SharedDict is responsible for
        (optionally) naming it and deciding the max size (in bytes) that
        it may be.  It does this via args to the c'tor.

        Subsequent processes may safely omit the size arg but must pass the
        name of the existing block (see get_name()).

        Raises:
            AssertionError: size_bytes is not positive, or a new block must
                be created and size_bytes was not given.
        """
        assert size_bytes is None or size_bytes > 0
        self._serializer = PickleSerializer()
        self.shared_memory = self._get_or_create_memory_block(name, size_bytes)
        self._ensure_memory_initialization()
        self.name = self.shared_memory.name

    def get_name(self) -> str:
        """Returns the name of the shared memory buffer backing the dict."""
        return self.name

    def _get_or_create_memory_block(
        self,
        name: Optional[str] = None,
        size_bytes: Optional[int] = None,
    ) -> shared_memory.SharedMemory:
        """Attach to the block called name, creating it if it does not exist.

        SharedMemory rejects name=None unless create=True, so with no name
        there is nothing to attach to and a new block is created directly.
        """
        if name is not None:
            try:
                return shared_memory.SharedMemory(name=name)
            except FileNotFoundError:
                pass  # does not exist yet; fall through and create it
        assert size_bytes is not None
        return shared_memory.SharedMemory(name=name, create=True, size=size_bytes)

    def _ensure_memory_initialization(self) -> None:
        """Store an empty dict in the buffer if nothing has been written yet."""
        with SharedDict.LOCK:
            buf = self.shared_memory.buf
            # Only the first byte matters; avoid copying the whole buffer.
            memory_is_empty = len(buf) == 0 or buf[0] == SharedDict.NULL_BYTE[0]
            if memory_is_empty:
                self.clear()

    def _write_memory(self, db: Dict[Hashable, Any]) -> None:
        """Serialize db and store it at the start of the shared buffer.

        Bytes past the new serialized data are left as they were; the
        serializer ignores anything after the end of its pickle.

        Raises:
            ValueError: the serialized dict does not fit in the buffer.
        """
        data = self._serializer.dumps(db)
        with SharedDict.LOCK:
            try:
                self.shared_memory.buf[: len(data)] = data
            except ValueError as e:
                raise ValueError("exceeds available storage") from e

    def _read_memory(self) -> Dict[Hashable, Any]:
        """Deserialize and return a fresh copy of the dict in the buffer."""
        with SharedDict.LOCK:
            return self._serializer.loads(self.shared_memory.buf.tobytes())

    @contextmanager
    def _modify_dict(self) -> Iterator[Dict[Hashable, Any]]:
        """Yield the current dict for mutation, then write it back.

        The lock is held for the whole read-modify-write.  If the with-body
        raises, the exception propagates and nothing is written back.
        """
        with SharedDict.LOCK:
            db = self._read_memory()
            yield db
            self._write_memory(db)

    def close(self) -> None:
        """Unmap the shared dict and memory behind it from this
        process.  Called automatically by __del__."""
        if not hasattr(self, 'shared_memory'):
            return  # __init__ failed before the block was set up
        self.shared_memory.close()

    def cleanup(self) -> None:
        """Unlink the shared dict and memory behind it.  Only the last process
        should invoke this.  Not called automatically."""
        if not hasattr(self, 'shared_memory'):
            return
        with SharedDict.LOCK:
            self.shared_memory.unlink()

    def clear(self) -> None:
        """Clear the dict."""
        self._write_memory({})

    def copy(self) -> Dict[Hashable, Any]:
        """Returns a shallow copy of the dict."""
        return self._read_memory()

    def __getitem__(self, key: Hashable) -> Any:
        return self._read_memory()[key]

    def __setitem__(self, key: Hashable, value: Any) -> None:
        with self._modify_dict() as db:
            db[key] = value

    def __len__(self) -> int:
        return len(self._read_memory())

    def __delitem__(self, key: Hashable) -> None:
        with self._modify_dict() as db:
            del db[key]

    def __iter__(self) -> Iterator[Hashable]:
        return iter(self._read_memory())

    def __reversed__(self) -> Iterator[Hashable]:
        return reversed(self._read_memory())

    def __del__(self) -> None:
        self.close()

    def __contains__(self, key: Hashable) -> bool:
        return key in self._read_memory()

    def __eq__(self, other: Any) -> bool:
        return self._read_memory() == other

    def __ne__(self, other: Any) -> bool:
        return self._read_memory() != other

    def __str__(self) -> str:
        return str(self._read_memory())

    def __repr__(self) -> str:
        return repr(self._read_memory())

    def get(self, key: Hashable, default: Optional[Any] = None) -> Any:
        """Gets the value associated with key or a default."""
        return self._read_memory().get(key, default)

    def keys(self) -> KeysView[Hashable]:
        """Returns the keys of a snapshot of the dict."""
        return self._read_memory().keys()

    def values(self) -> ValuesView[Any]:
        """Returns the values of a snapshot of the dict."""
        return self._read_memory().values()

    def items(self) -> ItemsView[Hashable, Any]:
        """Returns the (key, value) pairs of a snapshot of the dict."""
        return self._read_memory().items()

    def popitem(self) -> Tuple[Hashable, Any]:
        """Remove and return the last added item."""
        with self._modify_dict() as db:
            return db.popitem()

    def pop(self, key: Hashable, default: Any = _NO_DEFAULT) -> Any:
        """Remove and return the value associated with key or a default.

        Like dict.pop: raises KeyError for a missing key only when no default
        was passed.  An explicit default of None is honored.
        """
        with self._modify_dict() as db:
            if default is SharedDict._NO_DEFAULT:
                return db.pop(key)
            return db.pop(key, default)

    def update(self, other=(), /, **kwds) -> None:
        """Update the dict from a mapping/iterable of pairs and keywords."""
        with self._modify_dict() as db:
            db.update(other, **kwds)

    def setdefault(self, key: Hashable, default: Optional[Any] = None) -> Any:
        """Return the value for key, inserting default first if missing."""
        with self._modify_dict() as db:
            return db.setdefault(key, default)
239 if __name__ == '__main__':