Improve documentation.
[pyutils.git] / src / pyutils / collectionz / shared_dict.py
1 #!/usr/bin/env python3
2
3 """
4 The MIT License (MIT)
5
6 Copyright (c) 2020 LuizaLabs
7
8 Additions/Modifications Copyright (c) 2022 Scott Gasch
9
10 Permission is hereby granted, free of charge, to any person obtaining a copy
11 of this software and associated documentation files (the "Software"), to deal
12 in the Software without restriction, including without limitation the rights
13 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14 copies of the Software, and to permit persons to whom the Software is
15 furnished to do so, subject to the following conditions:
16
17 The above copyright notice and this permission notice shall be included in all
18 copies or substantial portions of the Software.
19
20 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26 SOFTWARE.
27
28 This class is based on https://github.com/luizalabs/shared-memory-dict.
29 For details about what is preserved from the original and what was changed
30 by Scott, see NOTICE at the root of this module.
31 """
32
33 import pickle
34 from contextlib import contextmanager
35 from multiprocessing import RLock, shared_memory
36 from typing import (
37     Any,
38     Dict,
39     Hashable,
40     ItemsView,
41     Iterator,
42     KeysView,
43     Optional,
44     Tuple,
45     ValuesView,
46 )
47
48
49 class PickleSerializer:
50     """A serializer that uses pickling.  Used to read/write bytes in the shared
51     memory region and interpret them as a dict."""
52
53     def dumps(self, obj: Dict[Hashable, Any]) -> bytes:
54         try:
55             return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
56         except pickle.PicklingError as e:
57             raise Exception from e
58
59     def loads(self, data: bytes) -> Dict[Hashable, Any]:
60         try:
61             return pickle.loads(data)
62         except pickle.UnpicklingError as e:
63             raise Exception from e
64
65
66 # TODOs: profile the serializers and figure out the fastest one.  Can
67 # we use a ChainMap to avoid the constant de/re-serialization of the
68 # whole thing?
69
70
71 class SharedDict(object):
72     """This class emulates the dict container but uses a
73     Multiprocessing.SharedMemory region to back the dict such that it
74     can be read and written by multiple independent processes at the
75     same time.  Because it constantly de/re-serializes the dict, it is
76     much slower than a normal dict.
77
78     """
79
80     NULL_BYTE = b'\x00'
81     LOCK = RLock()
82
83     def __init__(
84         self,
85         name: Optional[str] = None,
86         size_bytes: Optional[int] = None,
87     ) -> None:
88         """
89         Creates or attaches a shared dictionary back by a SharedMemory buffer.
90         For create semantics, a unique name (string) and a max dictionary size
91         (expressed in bytes) must be provided.  For attach semantics, these are
92         ignored.
93
94         The first process that creates the SharedDict is responsible for
95         (optionally) naming it and deciding the max size (in bytes) that
96         it may be.  It does this via args to the c'tor.
97
98         Subsequent processes may safely omit name and size args.
99
100         """
101         assert size_bytes is None or size_bytes > 0
102         self._serializer = PickleSerializer()
103         self.shared_memory = self._get_or_create_memory_block(name, size_bytes)
104         self._ensure_memory_initialization()
105         self.name = self.shared_memory.name
106
107     def get_name(self):
108         """Returns the name of the shared memory buffer backing the dict."""
109         return self.name
110
111     def _get_or_create_memory_block(
112         self,
113         name: Optional[str] = None,
114         size_bytes: Optional[int] = None,
115     ) -> shared_memory.SharedMemory:
116         try:
117             return shared_memory.SharedMemory(name=name)
118         except FileNotFoundError:
119             assert size_bytes is not None
120             return shared_memory.SharedMemory(name=name, create=True, size=size_bytes)
121
122     def _ensure_memory_initialization(self):
123         with SharedDict.LOCK:
124             memory_is_empty = (
125                 bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
126             )
127             if memory_is_empty:
128                 self.clear()
129
130     def _write_memory(self, db: Dict[Hashable, Any]) -> None:
131         data = self._serializer.dumps(db)
132         with SharedDict.LOCK:
133             try:
134                 self.shared_memory.buf[: len(data)] = data
135             except ValueError as e:
136                 raise ValueError("exceeds available storage") from e
137
138     def _read_memory(self) -> Dict[Hashable, Any]:
139         with SharedDict.LOCK:
140             return self._serializer.loads(self.shared_memory.buf.tobytes())
141
142     @contextmanager
143     def _modify_dict(self):
144         with SharedDict.LOCK:
145             db = self._read_memory()
146             yield db
147             self._write_memory(db)
148
149     def close(self) -> None:
150         """Unmap the shared dict and memory behind it from this
151         process.  Called by automatically __del__"""
152         if not hasattr(self, 'shared_memory'):
153             return
154         self.shared_memory.close()
155
156     def cleanup(self) -> None:
157         """Unlink the shared dict and memory behind it.  Only the last process should
158         invoke this.  Not called automatically."""
159         if not hasattr(self, 'shared_memory'):
160             return
161         with SharedDict.LOCK:
162             self.shared_memory.unlink()
163
164     def clear(self) -> None:
165         """Clear the dict."""
166         self._write_memory({})
167
168     def copy(self) -> Dict[Hashable, Any]:
169         """Returns a shallow copy of the dict."""
170         return self._read_memory()
171
172     def __getitem__(self, key: Hashable) -> Any:
173         return self._read_memory()[key]
174
175     def __setitem__(self, key: Hashable, value: Any) -> None:
176         with self._modify_dict() as db:
177             db[key] = value
178
179     def __len__(self) -> int:
180         return len(self._read_memory())
181
182     def __delitem__(self, key: Hashable) -> None:
183         with self._modify_dict() as db:
184             del db[key]
185
186     def __iter__(self) -> Iterator[Hashable]:
187         return iter(self._read_memory())
188
189     def __reversed__(self) -> Iterator[Hashable]:
190         return reversed(self._read_memory())
191
192     def __del__(self) -> None:
193         self.close()
194
195     def __contains__(self, key: Hashable) -> bool:
196         return key in self._read_memory()
197
198     def __eq__(self, other: Any) -> bool:
199         return self._read_memory() == other
200
201     def __ne__(self, other: Any) -> bool:
202         return self._read_memory() != other
203
204     def __str__(self) -> str:
205         return str(self._read_memory())
206
207     def __repr__(self) -> str:
208         return repr(self._read_memory())
209
210     def get(self, key: str, default: Optional[Any] = None) -> Any:
211         """Gets the value associated with key or a default."""
212         return self._read_memory().get(key, default)
213
214     def keys(self) -> KeysView[Hashable]:
215         return self._read_memory().keys()
216
217     def values(self) -> ValuesView[Any]:
218         return self._read_memory().values()
219
220     def items(self) -> ItemsView[Hashable, Any]:
221         return self._read_memory().items()
222
223     def popitem(self) -> Tuple[Hashable, Any]:
224         """Remove and return the last added item."""
225         with self._modify_dict() as db:
226             return db.popitem()
227
228     def pop(self, key: Hashable, default: Optional[Any] = None) -> Any:
229         """Remove and return the value associated with key or a default"""
230         with self._modify_dict() as db:
231             if default is None:
232                 return db.pop(key)
233             return db.pop(key, default)
234
235     def update(self, other=(), /, **kwds):
236         with self._modify_dict() as db:
237             db.update(other, **kwds)
238
239     def setdefault(self, key: Hashable, default: Optional[Any] = None):
240         with self._modify_dict() as db:
241             return db.setdefault(key, default)