2c05809749a7904fdd37bcf6fd8a737bdab0c929
[pyutils.git] / src / pyutils / collectionz / shared_dict.py
1 #!/usr/bin/env python3
2
3 """The MIT License (MIT)
4
5 Copyright (c) 2020 LuizaLabs
6
7 Additions/Modifications Copyright (c) 2022 Scott Gasch
8
9 Permission is hereby granted, free of charge, to any person obtaining a copy
10 of this software and associated documentation files (the "Software"), to deal
11 in the Software without restriction, including without limitation the rights
12 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 copies of the Software, and to permit persons to whom the Software is
14 furnished to do so, subject to the following conditions:
15
16 The above copyright notice and this permission notice shall be included in all
17 copies or substantial portions of the Software.
18
19 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 SOFTWARE.
26
27 This class is based on
28 https://github.com/luizalabs/shared-memory-dict.  For details about
29 what is preserved from the original and what was changed by Scott, see
30 `NOTICE
31 <https://wannabe.guru.org/gitweb/?p=pyutils.git;a=blob_plain;f=NOTICE;hb=HEAD>`_
32 at the root of this module.
33
34 """
35
36 import pickle
37 from contextlib import contextmanager
38 from multiprocessing import RLock, shared_memory
39 from typing import (
40     Any,
41     Dict,
42     Hashable,
43     ItemsView,
44     Iterator,
45     KeysView,
46     Optional,
47     Tuple,
48     ValuesView,
49 )
50
51
52 class PickleSerializer:
53     """A serializer that uses pickling.  Used to read/write bytes in the shared
54     memory region and interpret them as a dict."""
55
56     def dumps(self, obj: Dict[Hashable, Any]) -> bytes:
57         try:
58             return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
59         except pickle.PicklingError as e:
60             raise Exception from e
61
62     def loads(self, data: bytes) -> Dict[Hashable, Any]:
63         try:
64             return pickle.loads(data)
65         except pickle.UnpicklingError as e:
66             raise Exception from e
67
68
69 # TODOs: profile the serializers and figure out the fastest one.  Can
70 # we use a ChainMap to avoid the constant de/re-serialization of the
71 # whole thing?
72
73
74 class SharedDict(object):
75     """This class emulates the dict container but uses a
76     Multiprocessing.SharedMemory region to back the dict such that it
77     can be read and written by multiple independent processes at the
78     same time.  Because it constantly de/re-serializes the dict, it is
79     much slower than a normal dict.
80
81     """
82
83     NULL_BYTE = b'\x00'
84     LOCK = RLock()
85
86     def __init__(
87         self,
88         name: Optional[str] = None,
89         size_bytes: Optional[int] = None,
90     ) -> None:
91         """
92         Creates or attaches a shared dictionary back by a SharedMemory buffer.
93         For create semantics, a unique name (string) and a max dictionary size
94         (expressed in bytes) must be provided.  For attach semantics, these are
95         ignored.
96
97         The first process that creates the SharedDict is responsible for
98         (optionally) naming it and deciding the max size (in bytes) that
99         it may be.  It does this via args to the c'tor.
100
101         Subsequent processes may safely omit name and size args.
102
103         Args:
104             name: the name of the shared dict, only required for initial caller
105             size_bytes: the maximum size of data storable in the shared dict,
106                 only required for the first caller.
107
108         """
109         assert size_bytes is None or size_bytes > 0
110         self._serializer = PickleSerializer()
111         self.shared_memory = self._get_or_create_memory_block(name, size_bytes)
112         self._ensure_memory_initialization()
113         self.name = self.shared_memory.name
114
115     def get_name(self):
116         """
117         Returns:
118             The name of the shared memory buffer backing the dict.
119         """
120         return self.name
121
122     def _get_or_create_memory_block(
123         self,
124         name: Optional[str] = None,
125         size_bytes: Optional[int] = None,
126     ) -> shared_memory.SharedMemory:
127         """Internal helper."""
128         try:
129             return shared_memory.SharedMemory(name=name)
130         except FileNotFoundError:
131             assert size_bytes is not None
132             return shared_memory.SharedMemory(name=name, create=True, size=size_bytes)
133
134     def _ensure_memory_initialization(self):
135         """Internal helper."""
136         with SharedDict.LOCK:
137             memory_is_empty = (
138                 bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
139             )
140             if memory_is_empty:
141                 self.clear()
142
143     def _write_memory(self, db: Dict[Hashable, Any]) -> None:
144         """Internal helper."""
145         data = self._serializer.dumps(db)
146         with SharedDict.LOCK:
147             try:
148                 self.shared_memory.buf[: len(data)] = data
149             except ValueError as e:
150                 raise ValueError("exceeds available storage") from e
151
152     def _read_memory(self) -> Dict[Hashable, Any]:
153         """Internal helper."""
154         with SharedDict.LOCK:
155             return self._serializer.loads(self.shared_memory.buf.tobytes())
156
157     @contextmanager
158     def _modify_dict(self):
159         """Internal helper."""
160         with SharedDict.LOCK:
161             db = self._read_memory()
162             yield db
163             self._write_memory(db)
164
165     def close(self) -> None:
166         """Unmap the shared dict and memory behind it from this
167         process.  Called by automatically :meth:`__del__`.
168         """
169         if not hasattr(self, 'shared_memory'):
170             return
171         self.shared_memory.close()
172
173     def cleanup(self) -> None:
174         """Unlink the shared dict and memory behind it.  Only the last process
175         should invoke this.  Not called automatically."""
176         if not hasattr(self, 'shared_memory'):
177             return
178         with SharedDict.LOCK:
179             self.shared_memory.unlink()
180
181     def clear(self) -> None:
182         """Clears the shared dict."""
183         self._write_memory({})
184
185     def copy(self) -> Dict[Hashable, Any]:
186         """
187         Returns:
188             A shallow copy of the shared dict.
189         """
190         return self._read_memory()
191
192     def __getitem__(self, key: Hashable) -> Any:
193         return self._read_memory()[key]
194
195     def __setitem__(self, key: Hashable, value: Any) -> None:
196         with self._modify_dict() as db:
197             db[key] = value
198
199     def __len__(self) -> int:
200         return len(self._read_memory())
201
202     def __delitem__(self, key: Hashable) -> None:
203         with self._modify_dict() as db:
204             del db[key]
205
206     def __iter__(self) -> Iterator[Hashable]:
207         return iter(self._read_memory())
208
209     def __reversed__(self) -> Iterator[Hashable]:
210         return reversed(self._read_memory())
211
212     def __del__(self) -> None:
213         self.close()
214
215     def __contains__(self, key: Hashable) -> bool:
216         return key in self._read_memory()
217
218     def __eq__(self, other: Any) -> bool:
219         return self._read_memory() == other
220
221     def __ne__(self, other: Any) -> bool:
222         return self._read_memory() != other
223
224     def __str__(self) -> str:
225         return str(self._read_memory())
226
227     def __repr__(self) -> str:
228         return repr(self._read_memory())
229
230     def get(self, key: str, default: Optional[Any] = None) -> Any:
231         """
232         Args:
233             key: the key to lookup
234             default: the value returned if key is not present
235
236         Returns:
237             The value associated with key or a default.
238         """
239         return self._read_memory().get(key, default)
240
241     def keys(self) -> KeysView[Hashable]:
242         return self._read_memory().keys()
243
244     def values(self) -> ValuesView[Any]:
245         return self._read_memory().values()
246
247     def items(self) -> ItemsView[Hashable, Any]:
248         return self._read_memory().items()
249
250     def popitem(self) -> Tuple[Hashable, Any]:
251         """Remove and return the last added item."""
252         with self._modify_dict() as db:
253             return db.popitem()
254
255     def pop(self, key: Hashable, default: Optional[Any] = None) -> Any:
256         """Remove and return the value associated with key or a default"""
257         with self._modify_dict() as db:
258             if default is None:
259                 return db.pop(key)
260             return db.pop(key, default)
261
262     def update(self, other=(), /, **kwds):
263         with self._modify_dict() as db:
264             db.update(other, **kwds)
265
266     def setdefault(self, key: Hashable, default: Optional[Any] = None):
267         with self._modify_dict() as db:
268             return db.setdefault(key, default)