Easier and more self documenting patterns for loading/saving Persistent
[python_utils.git] / collect / shared_dict.py
1 #!/usr/bin/env python3
2
3 """
4 The MIT License (MIT)
5
6 Copyright (c) 2020 LuizaLabs
7 Additions/Modifications Copyright (c) 2022 Scott Gasch
8
9 Permission is hereby granted, free of charge, to any person obtaining a copy
10 of this software and associated documentation files (the "Software"), to deal
11 in the Software without restriction, including without limitation the rights
12 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 copies of the Software, and to permit persons to whom the Software is
14 furnished to do so, subject to the following conditions:
15
16 The above copyright notice and this permission notice shall be included in all
17 copies or substantial portions of the Software.
18
19 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 SOFTWARE.
26
27 This class is based on https://github.com/luizalabs/shared-memory-dict
28 """
29
30 import pickle
31 from contextlib import contextmanager
32 from multiprocessing import RLock, shared_memory
33 from typing import (
34     Any,
35     Dict,
36     Hashable,
37     ItemsView,
38     Iterator,
39     KeysView,
40     Optional,
41     Tuple,
42     ValuesView,
43 )
44
45
46 class PickleSerializer:
47     """A serializer that uses pickling.  Used to read/write bytes in the shared
48     memory region and interpret them as a dict."""
49
50     def dumps(self, obj: Dict[Hashable, Any]) -> bytes:
51         try:
52             return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
53         except pickle.PicklingError as e:
54             raise Exception from e
55
56     def loads(self, data: bytes) -> Dict[Hashable, Any]:
57         try:
58             return pickle.loads(data)
59         except pickle.UnpicklingError as e:
60             raise Exception from e
61
62
63 # TODOs: profile the serializers and figure out the fastest one.  Can
64 # we use a ChainMap to avoid the constant de/re-serialization of the
65 # whole thing?
66
67
68 class SharedDict(object):
69     """This class emulates the dict container but uses a
70     Multiprocessing.SharedMemory region to back the dict such that it
71     can be read and written by multiple independent processes at the
72     same time.  Because it constantly de/re-serializes the dict, it is
73     much slower than a normal dict.
74
75     """
76
77     NULL_BYTE = b'\x00'
78     LOCK = RLock()
79
80     def __init__(
81         self,
82         name: Optional[str] = None,
83         size_bytes: Optional[int] = None,
84     ) -> None:
85         """
86         Creates or attaches a shared dictionary back by a SharedMemory buffer.
87         For create semantics, a unique name (string) and a max dictionary size
88         (expressed in bytes) must be provided.  For attach semantics, these are
89         ignored.
90
91         The first process that creates the SharedDict is responsible for
92         (optionally) naming it and deciding the max size (in bytes) that
93         it may be.  It does this via args to the c'tor.
94
95         Subsequent processes may safely omit name and size args.
96
97         """
98         assert size_bytes is None or size_bytes > 0
99         self._serializer = PickleSerializer()
100         self.shared_memory = self._get_or_create_memory_block(name, size_bytes)
101         self._ensure_memory_initialization()
102         self.name = self.shared_memory.name
103
104     def get_name(self):
105         """Returns the name of the shared memory buffer backing the dict."""
106         return self.name
107
108     def _get_or_create_memory_block(
109         self,
110         name: Optional[str] = None,
111         size_bytes: Optional[int] = None,
112     ) -> shared_memory.SharedMemory:
113         try:
114             return shared_memory.SharedMemory(name=name)
115         except FileNotFoundError:
116             assert size_bytes is not None
117             return shared_memory.SharedMemory(name=name, create=True, size=size_bytes)
118
119     def _ensure_memory_initialization(self):
120         with SharedDict.LOCK:
121             memory_is_empty = bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
122             if memory_is_empty:
123                 self.clear()
124
125     def _write_memory(self, db: Dict[Hashable, Any]) -> None:
126         data = self._serializer.dumps(db)
127         with SharedDict.LOCK:
128             try:
129                 self.shared_memory.buf[: len(data)] = data
130             except ValueError as e:
131                 raise ValueError("exceeds available storage") from e
132
133     def _read_memory(self) -> Dict[Hashable, Any]:
134         with SharedDict.LOCK:
135             return self._serializer.loads(self.shared_memory.buf.tobytes())
136
137     @contextmanager
138     def _modify_dict(self):
139         with SharedDict.LOCK:
140             db = self._read_memory()
141             yield db
142             self._write_memory(db)
143
144     def close(self) -> None:
145         """Unmap the shared dict and memory behind it from this
146         process.  Called by automatically __del__"""
147         if not hasattr(self, 'shared_memory'):
148             return
149         self.shared_memory.close()
150
151     def cleanup(self) -> None:
152         """Unlink the shared dict and memory behind it.  Only the last process should
153         invoke this.  Not called automatically."""
154         if not hasattr(self, 'shared_memory'):
155             return
156         with SharedDict.LOCK:
157             self.shared_memory.unlink()
158
159     def clear(self) -> None:
160         """Clear the dict."""
161         self._write_memory({})
162
163     def copy(self) -> Dict[Hashable, Any]:
164         """Returns a shallow copy of the dict."""
165         return self._read_memory()
166
167     def __getitem__(self, key: Hashable) -> Any:
168         return self._read_memory()[key]
169
170     def __setitem__(self, key: Hashable, value: Any) -> None:
171         with self._modify_dict() as db:
172             db[key] = value
173
174     def __len__(self) -> int:
175         return len(self._read_memory())
176
177     def __delitem__(self, key: Hashable) -> None:
178         with self._modify_dict() as db:
179             del db[key]
180
181     def __iter__(self) -> Iterator[Hashable]:
182         return iter(self._read_memory())
183
184     def __reversed__(self) -> Iterator[Hashable]:
185         return reversed(self._read_memory())
186
187     def __del__(self) -> None:
188         self.close()
189
190     def __contains__(self, key: Hashable) -> bool:
191         return key in self._read_memory()
192
193     def __eq__(self, other: Any) -> bool:
194         return self._read_memory() == other
195
196     def __ne__(self, other: Any) -> bool:
197         return self._read_memory() != other
198
199     def __str__(self) -> str:
200         return str(self._read_memory())
201
202     def __repr__(self) -> str:
203         return repr(self._read_memory())
204
205     def get(self, key: str, default: Optional[Any] = None) -> Any:
206         """Gets the value associated with key or a default."""
207         return self._read_memory().get(key, default)
208
209     def keys(self) -> KeysView[Hashable]:
210         return self._read_memory().keys()
211
212     def values(self) -> ValuesView[Any]:
213         return self._read_memory().values()
214
215     def items(self) -> ItemsView[Hashable, Any]:
216         return self._read_memory().items()
217
218     def popitem(self) -> Tuple[Hashable, Any]:
219         """Remove and return the last added item."""
220         with self._modify_dict() as db:
221             return db.popitem()
222
223     def pop(self, key: Hashable, default: Optional[Any] = None) -> Any:
224         """Remove and return the value associated with key or a default"""
225         with self._modify_dict() as db:
226             if default is None:
227                 return db.pop(key)
228             return db.pop(key, default)
229
230     def update(self, other=(), /, **kwds):
231         with self._modify_dict() as db:
232             db.update(other, **kwds)
233
234     def setdefault(self, key: Hashable, default: Optional[Any] = None):
235         with self._modify_dict() as db:
236             return db.setdefault(key, default)
237
238
239 if __name__ == '__main__':
240     import doctest
241
242     doctest.testmod()