Migration from old pyutilz package name (which, in turn, came from
[pyutils.git] / src / pyutils / collectionz / shared_dict.py
1 #!/usr/bin/env python3
2
3 """
4 The MIT License (MIT)
5
6 Copyright (c) 2020 LuizaLabs
7 Additions/Modifications Copyright (c) 2022 Scott Gasch
8
9 Permission is hereby granted, free of charge, to any person obtaining a copy
10 of this software and associated documentation files (the "Software"), to deal
11 in the Software without restriction, including without limitation the rights
12 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 copies of the Software, and to permit persons to whom the Software is
14 furnished to do so, subject to the following conditions:
15
16 The above copyright notice and this permission notice shall be included in all
17 copies or substantial portions of the Software.
18
19 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 SOFTWARE.
26
27 This class is based on https://github.com/luizalabs/shared-memory-dict.
28 For details about what is preserved from the original and what was changed
29 by Scott, see NOTICE at the root of this module.
30 """
31
32 import pickle
33 from contextlib import contextmanager
34 from multiprocessing import RLock, shared_memory
35 from typing import (
36     Any,
37     Dict,
38     Hashable,
39     ItemsView,
40     Iterator,
41     KeysView,
42     Optional,
43     Tuple,
44     ValuesView,
45 )
46
47
48 class PickleSerializer:
49     """A serializer that uses pickling.  Used to read/write bytes in the shared
50     memory region and interpret them as a dict."""
51
52     def dumps(self, obj: Dict[Hashable, Any]) -> bytes:
53         try:
54             return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
55         except pickle.PicklingError as e:
56             raise Exception from e
57
58     def loads(self, data: bytes) -> Dict[Hashable, Any]:
59         try:
60             return pickle.loads(data)
61         except pickle.UnpicklingError as e:
62             raise Exception from e
63
64
65 # TODOs: profile the serializers and figure out the fastest one.  Can
66 # we use a ChainMap to avoid the constant de/re-serialization of the
67 # whole thing?
68
69
70 class SharedDict(object):
71     """This class emulates the dict container but uses a
72     Multiprocessing.SharedMemory region to back the dict such that it
73     can be read and written by multiple independent processes at the
74     same time.  Because it constantly de/re-serializes the dict, it is
75     much slower than a normal dict.
76
77     """
78
79     NULL_BYTE = b'\x00'
80     LOCK = RLock()
81
82     def __init__(
83         self,
84         name: Optional[str] = None,
85         size_bytes: Optional[int] = None,
86     ) -> None:
87         """
88         Creates or attaches a shared dictionary back by a SharedMemory buffer.
89         For create semantics, a unique name (string) and a max dictionary size
90         (expressed in bytes) must be provided.  For attach semantics, these are
91         ignored.
92
93         The first process that creates the SharedDict is responsible for
94         (optionally) naming it and deciding the max size (in bytes) that
95         it may be.  It does this via args to the c'tor.
96
97         Subsequent processes may safely omit name and size args.
98
99         """
100         assert size_bytes is None or size_bytes > 0
101         self._serializer = PickleSerializer()
102         self.shared_memory = self._get_or_create_memory_block(name, size_bytes)
103         self._ensure_memory_initialization()
104         self.name = self.shared_memory.name
105
106     def get_name(self):
107         """Returns the name of the shared memory buffer backing the dict."""
108         return self.name
109
110     def _get_or_create_memory_block(
111         self,
112         name: Optional[str] = None,
113         size_bytes: Optional[int] = None,
114     ) -> shared_memory.SharedMemory:
115         try:
116             return shared_memory.SharedMemory(name=name)
117         except FileNotFoundError:
118             assert size_bytes is not None
119             return shared_memory.SharedMemory(name=name, create=True, size=size_bytes)
120
121     def _ensure_memory_initialization(self):
122         with SharedDict.LOCK:
123             memory_is_empty = (
124                 bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
125             )
126             if memory_is_empty:
127                 self.clear()
128
129     def _write_memory(self, db: Dict[Hashable, Any]) -> None:
130         data = self._serializer.dumps(db)
131         with SharedDict.LOCK:
132             try:
133                 self.shared_memory.buf[: len(data)] = data
134             except ValueError as e:
135                 raise ValueError("exceeds available storage") from e
136
137     def _read_memory(self) -> Dict[Hashable, Any]:
138         with SharedDict.LOCK:
139             return self._serializer.loads(self.shared_memory.buf.tobytes())
140
141     @contextmanager
142     def _modify_dict(self):
143         with SharedDict.LOCK:
144             db = self._read_memory()
145             yield db
146             self._write_memory(db)
147
148     def close(self) -> None:
149         """Unmap the shared dict and memory behind it from this
150         process.  Called by automatically __del__"""
151         if not hasattr(self, 'shared_memory'):
152             return
153         self.shared_memory.close()
154
155     def cleanup(self) -> None:
156         """Unlink the shared dict and memory behind it.  Only the last process should
157         invoke this.  Not called automatically."""
158         if not hasattr(self, 'shared_memory'):
159             return
160         with SharedDict.LOCK:
161             self.shared_memory.unlink()
162
163     def clear(self) -> None:
164         """Clear the dict."""
165         self._write_memory({})
166
167     def copy(self) -> Dict[Hashable, Any]:
168         """Returns a shallow copy of the dict."""
169         return self._read_memory()
170
171     def __getitem__(self, key: Hashable) -> Any:
172         return self._read_memory()[key]
173
174     def __setitem__(self, key: Hashable, value: Any) -> None:
175         with self._modify_dict() as db:
176             db[key] = value
177
178     def __len__(self) -> int:
179         return len(self._read_memory())
180
181     def __delitem__(self, key: Hashable) -> None:
182         with self._modify_dict() as db:
183             del db[key]
184
185     def __iter__(self) -> Iterator[Hashable]:
186         return iter(self._read_memory())
187
188     def __reversed__(self) -> Iterator[Hashable]:
189         return reversed(self._read_memory())
190
191     def __del__(self) -> None:
192         self.close()
193
194     def __contains__(self, key: Hashable) -> bool:
195         return key in self._read_memory()
196
197     def __eq__(self, other: Any) -> bool:
198         return self._read_memory() == other
199
200     def __ne__(self, other: Any) -> bool:
201         return self._read_memory() != other
202
203     def __str__(self) -> str:
204         return str(self._read_memory())
205
206     def __repr__(self) -> str:
207         return repr(self._read_memory())
208
209     def get(self, key: str, default: Optional[Any] = None) -> Any:
210         """Gets the value associated with key or a default."""
211         return self._read_memory().get(key, default)
212
213     def keys(self) -> KeysView[Hashable]:
214         return self._read_memory().keys()
215
216     def values(self) -> ValuesView[Any]:
217         return self._read_memory().values()
218
219     def items(self) -> ItemsView[Hashable, Any]:
220         return self._read_memory().items()
221
222     def popitem(self) -> Tuple[Hashable, Any]:
223         """Remove and return the last added item."""
224         with self._modify_dict() as db:
225             return db.popitem()
226
227     def pop(self, key: Hashable, default: Optional[Any] = None) -> Any:
228         """Remove and return the value associated with key or a default"""
229         with self._modify_dict() as db:
230             if default is None:
231                 return db.pop(key)
232             return db.pop(key, default)
233
234     def update(self, other=(), /, **kwds):
235         with self._modify_dict() as db:
236             db.update(other, **kwds)
237
238     def setdefault(self, key: Hashable, default: Optional[Any] = None):
239         with self._modify_dict() as db:
240             return db.setdefault(key, default)