Merge simple and typing.
[pyutils.git] / src / pyutils / collectionz / shared_dict.py
1 #!/usr/bin/env python3
2
3 """The MIT License (MIT)
4
5 Copyright (c) 2020 LuizaLabs
6
7 Additions/Modifications Copyright (c) 2022 Scott Gasch
8
9 Permission is hereby granted, free of charge, to any person obtaining a copy
10 of this software and associated documentation files (the "Software"), to deal
11 in the Software without restriction, including without limitation the rights
12 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 copies of the Software, and to permit persons to whom the Software is
14 furnished to do so, subject to the following conditions:
15
16 The above copyright notice and this permission notice shall be included in all
17 copies or substantial portions of the Software.
18
19 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 SOFTWARE.
26
27 This class is based on
28 https://github.com/luizalabs/shared-memory-dict.  For details about
29 what is preserved from the original and what was changed by Scott, see
30 `NOTICE
31 <https://wannabe.guru.org/gitweb/?p=pyutils.git;a=blob_plain;f=NOTICE;hb=HEAD>`_
32 at the root of this module.
33
34 """
35
36 import pickle
37 from contextlib import contextmanager
38 from multiprocessing import RLock, shared_memory
39 from typing import (
40     Any,
41     Dict,
42     Hashable,
43     ItemsView,
44     Iterator,
45     KeysView,
46     Optional,
47     Tuple,
48     ValuesView,
49 )
50
51 from pyutils.typez.typing import Closable
52
53
54 class PickleSerializer:
55     """A serializer that uses pickling.  Used to read/write bytes in the shared
56     memory region and interpret them as a dict."""
57
58     def dumps(self, obj: Dict[Hashable, Any]) -> bytes:
59         try:
60             return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
61         except pickle.PicklingError as e:
62             raise Exception from e
63
64     def loads(self, data: bytes) -> Dict[Hashable, Any]:
65         try:
66             return pickle.loads(data)
67         except pickle.UnpicklingError as e:
68             raise Exception from e
69
70
71 # TODOs: profile the serializers and figure out the fastest one.  Can
72 # we use a ChainMap to avoid the constant de/re-serialization of the
73 # whole thing?
74
75
76 class SharedDict(Closable):
77     """This class emulates the dict container but uses a
78     `Multiprocessing.SharedMemory` region to back the dict such that it
79     can be read and written by multiple independent processes at the
80     same time.  Because it constantly de/re-serializes the dict, it is
81     much slower than a normal dict.
82
83     Example usage... one process should set up the shared memory::
84
85         from pyutils.collectionz.shared_dict import SharedDict
86
87         shared_memory_id = 'SharedDictIdentifier'
88         shared_memory_size_bytes = 4096
89         shared_memory = SharedDict(shared_memory_id, shared_memory_size_bytes)
90
91     Other processes can then attach to the shared memory by
92     referencing its name.  Don't try to pass the :class:`SharedDict` itself to
93     a child process.  Rather, just pass its name string.  You can create
94     child processes any way that Python supports.  The
95     `wordle example <https://wannabe.guru.org/gitweb/?p=pyutils.git;a=blob_plain;f=examples/wordle/wordle.py;h=df9874ee0b309e7a70a5a7c8900629869def3928;hb=HEAD>`__ uses the
96     parallelize framework with `SharedDict` but a simple `subprocess.run`,
97     `exec_utils`, `ProcessExecutor`, whatever::
98
99         from pyutils import exec_utils
100
101         processes = []
102         for i in range(10):
103             processes.append(
104                 exec_utils.cmd_in_background(
105                     f'myhelper.py --number {i} --shared_memory={shared_memory_id}'
106                 )
107             )
108
109     In the child process, attach the already created :class:`SharedDict`
110     using its name.  A size is not necessary when attaching to an
111     already created shared memory region -- it cannot be resized after
112     creation.  The name must be the same exact name that was used to
113     create it originally::
114
115         from pyutils.collectionz.shared_dict import SharedDict
116
117         shared_memory_id = config.config['shared_memory']
118         shared_memory = SharedDict(shared_memory_id)
119
120     The children processes (and parent process, also) can now just use
121     the shared memory like a normal `dict`::
122
123         if shared_memory[work_id] is None:
124             result = do_expensive_work(work_id)
125             shared_memory[work_id] = result
126
127     .. note::
128
129         It's pretty slow to mutate data in the shared memory.  First,
130         it needs to acquire an exclusive lock.  Second, it essentially
131         pickles an entire dict into the shared memory region.  So this
132         is not a data structure that is going to win awards for speed.
133         But it is a very convenient way to have a shared cache, for
134         example.  See the wordle example for a real life program using
135         `SharedDict` this way.  It basically saves the result of large
136         computations in a `SharedDict` thereby allowing all threads to
137         avoid recomputing that same expensive computation.  In this
138         scenario the slowness of the dict writes are more than paid
139         for by the avoidence of duplicated, expensive work.
140
141     Finally, someone (likely the main process) should call the :meth:`cleanup`
142     method when the shared memory region is no longer needed::
143
144         shared_memory.cleanup()
145
146     See also the `shared_dict_test.py <https://wannabe.guru.org/gitweb/?p=pyutils.git;a=blob_plain;f=tests/collectionz/shared_dict_test.py;h=0a684f4835554553018cefbc114034c2dc405794;hb=HEAD>`__ for an
147     example of using this class.
148
149     ---
150     """
151
152     NULL_BYTE = b"\x00"
153     LOCK = RLock()
154
155     def __init__(
156         self,
157         name: Optional[str] = None,
158         size_bytes: Optional[int] = None,
159     ) -> None:
160         """Creates or attaches a shared dictionary back by a
161         :class:`SharedMemory` buffer.  For create semantics, a unique
162         name (string) and a max dictionary size (expressed in bytes)
163         must be provided.  For attach semantics size is ignored.
164
165         .. warning::
166
167             Size is ignored on attach operations.  The size of the
168             shared memory region cannot be changed once it has been
169             created.
170
171         The first process that creates the :class:`SharedDict` is
172         responsible for (optionally) naming it and deciding the max
173         size (in bytes) that it may be.  It does this via args to the
174         c'tor.
175
176         Subsequent processes may safely the size arg.
177
178         Args:
179             name: the name of the shared dict, only required for initial caller
180             size_bytes: the maximum size of data storable in the shared dict,
181                 only required for the first caller.
182
183         """
184         assert size_bytes is None or size_bytes > 0
185         self._serializer = PickleSerializer()
186         self.shared_memory = self._get_or_create_memory_block(name, size_bytes)
187         self._ensure_memory_initialization()
188         self.name = self.shared_memory.name
189
190     def get_name(self):
191         """
192         Returns:
193             The name of the shared memory buffer backing the dict.
194         """
195         return self.name
196
197     def _get_or_create_memory_block(
198         self,
199         name: Optional[str] = None,
200         size_bytes: Optional[int] = None,
201     ) -> shared_memory.SharedMemory:
202         """Internal helper."""
203         try:
204             return shared_memory.SharedMemory(name=name)
205         except FileNotFoundError:
206             assert size_bytes is not None
207             return shared_memory.SharedMemory(name=name, create=True, size=size_bytes)
208
209     def _ensure_memory_initialization(self):
210         """Internal helper."""
211         with SharedDict.LOCK:
212             memory_is_empty = (
213                 bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b""
214             )
215             if memory_is_empty:
216                 self.clear()
217
218     def _write_memory(self, db: Dict[Hashable, Any]) -> None:
219         """Internal helper."""
220         data = self._serializer.dumps(db)
221         with SharedDict.LOCK:
222             try:
223                 self.shared_memory.buf[: len(data)] = data
224             except ValueError as e:
225                 raise ValueError("exceeds available storage") from e
226
227     def _read_memory(self) -> Dict[Hashable, Any]:
228         """Internal helper."""
229         with SharedDict.LOCK:
230             return self._serializer.loads(self.shared_memory.buf.tobytes())
231
232     @contextmanager
233     def _modify_dict(self):
234         """Internal helper."""
235         with SharedDict.LOCK:
236             db = self._read_memory()
237             yield db
238             self._write_memory(db)
239
240     def close(self) -> None:
241         """Unmap the shared dict and memory behind it from this
242         process.  Called by automatically :meth:`__del__`.
243         """
244         if not hasattr(self, "shared_memory"):
245             return
246         self.shared_memory.close()
247
248     def cleanup(self) -> None:
249         """Unlink the shared dict and memory behind it.  Only the last process
250         should invoke this.  Not called automatically."""
251         if not hasattr(self, "shared_memory"):
252             return
253         with SharedDict.LOCK:
254             self.shared_memory.unlink()
255
256     def clear(self) -> None:
257         """Clears the shared dict."""
258         self._write_memory({})
259
260     def copy(self) -> Dict[Hashable, Any]:
261         """
262         Returns:
263             A shallow copy of the shared dict.
264         """
265         return self._read_memory()
266
267     def __getitem__(self, key: Hashable) -> Any:
268         return self._read_memory()[key]
269
270     def __setitem__(self, key: Hashable, value: Any) -> None:
271         with self._modify_dict() as db:
272             db[key] = value
273
274     def __len__(self) -> int:
275         return len(self._read_memory())
276
277     def __delitem__(self, key: Hashable) -> None:
278         with self._modify_dict() as db:
279             del db[key]
280
281     def __iter__(self) -> Iterator[Hashable]:
282         return iter(self._read_memory())
283
284     def __reversed__(self) -> Iterator[Hashable]:
285         return reversed(self._read_memory())
286
287     def __del__(self) -> None:
288         self.close()
289
290     def __contains__(self, key: Hashable) -> bool:
291         return key in self._read_memory()
292
293     def __eq__(self, other: Any) -> bool:
294         return self._read_memory() == other
295
296     def __ne__(self, other: Any) -> bool:
297         return self._read_memory() != other
298
299     def __str__(self) -> str:
300         return str(self._read_memory())
301
302     def __repr__(self) -> str:
303         return repr(self._read_memory())
304
305     def get(self, key: str, default: Optional[Any] = None) -> Any:
306         """
307         Args:
308             key: the key to lookup
309             default: the value returned if key is not present
310
311         Returns:
312             The value associated with key or a default.
313         """
314         return self._read_memory().get(key, default)
315
316     def keys(self) -> KeysView[Hashable]:
317         return self._read_memory().keys()
318
319     def values(self) -> ValuesView[Any]:
320         return self._read_memory().values()
321
322     def items(self) -> ItemsView[Hashable, Any]:
323         return self._read_memory().items()
324
325     def popitem(self) -> Tuple[Hashable, Any]:
326         """Remove and return the last added item."""
327         with self._modify_dict() as db:
328             return db.popitem()
329
330     def pop(self, key: Hashable, default: Optional[Any] = None) -> Any:
331         """Remove and return the value associated with key or a default"""
332         with self._modify_dict() as db:
333             if default is None:
334                 return db.pop(key)
335             return db.pop(key, default)
336
337     def update(self, other=(), /, **kwds):
338         with self._modify_dict() as db:
339             db.update(other, **kwds)
340
341     def setdefault(self, key: Hashable, default: Optional[Any] = None):
342         with self._modify_dict() as db:
343             return db.setdefault(key, default)