Take the lock before unlinking the mmap'ed shared memory to ensure
[python_utils.git] / collect / shared_dict.py
1 #!/usr/bin/env python3
2
3 """
4 The MIT License (MIT)
5
6 Copyright (c) 2020 LuizaLabs
7 Additions Copyright (c) 2022 Scott Gasch
8
9 Permission is hereby granted, free of charge, to any person obtaining a copy
10 of this software and associated documentation files (the "Software"), to deal
11 in the Software without restriction, including without limitation the rights
12 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 copies of the Software, and to permit persons to whom the Software is
14 furnished to do so, subject to the following conditions:
15
16 The above copyright notice and this permission notice shall be included in all
17 copies or substantial portions of the Software.
18
19 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 SOFTWARE.
26
27 This class is based on https://github.com/luizalabs/shared-memory-dict
28 """
29
30 import pickle
31 from contextlib import contextmanager
32 from functools import wraps
33 from multiprocessing import RLock, shared_memory
34 from typing import (
35     Any,
36     Dict,
37     Generator,
38     ItemsView,
39     Iterator,
40     KeysView,
41     Optional,
42     ValuesView,
43 )
44
45 from decorator_utils import synchronized
46
47
48 class PickleSerializer:
49     def dumps(self, obj: dict) -> bytes:
50         try:
51             return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
52         except pickle.PicklingError as e:
53             raise Exception(e)
54
55     def loads(self, data: bytes) -> dict:
56         try:
57             return pickle.loads(data)
58         except pickle.UnpicklingError as e:
59             raise Exception(e)
60
61
62 # TODO: protobuf serializer?
63
64
65 class SharedDict(object):
66     NULL_BYTE = b'\x00'
67     MPLOCK = RLock()
68
69     def __init__(
70         self,
71         name: str,
72         size_bytes: Optional[int] = None,
73     ) -> None:
74         super().__init__()
75         self.name = name
76         self._serializer = PickleSerializer()
77         assert size_bytes is None or size_bytes > 0
78         self.shared_memory = self._get_or_create_memory_block(name, size_bytes)
79         self._ensure_memory_initialization()
80         self.lock = RLock()
81
82     def get_name(self):
83         return self.name
84
85     def _get_or_create_memory_block(
86         self,
87         name: str,
88         size_bytes: Optional[int] = None,
89     ) -> shared_memory.SharedMemory:
90         try:
91             return shared_memory.SharedMemory(name=name)
92         except FileNotFoundError:
93             assert size_bytes is not None
94             return shared_memory.SharedMemory(name=name, create=True, size=size_bytes)
95
96     def _ensure_memory_initialization(self):
97         memory_is_empty = (
98             bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
99         )
100         if memory_is_empty:
101             self.clear()
102
103     def close(self) -> None:
104         if not hasattr(self, 'shared_memory'):
105             return
106         self.shared_memory.close()
107
108     def cleanup(self) -> None:
109         if not hasattr(self, 'shared_memory'):
110             return
111         with SharedDict.MPLOCK:
112             self.shared_memory.unlink()
113
114     def clear(self) -> None:
115         self._save_memory({})
116
117     def popitem(self, last: Optional[bool] = None) -> Any:
118         with self._modify_db() as db:
119             return db.popitem()
120
121     @contextmanager
122     def _modify_db(self) -> Generator:
123         with SharedDict.MPLOCK:
124             db = self._read_memory()
125             yield db
126             self._save_memory(db)
127
128     def __getitem__(self, key: str) -> Any:
129         return self._read_memory()[key]
130
131     def __setitem__(self, key: str, value: Any) -> None:
132         with self._modify_db() as db:
133             db[key] = value
134
135     def __len__(self) -> int:
136         return len(self._read_memory())
137
138     def __delitem__(self, key: str) -> None:
139         with self._modify_db() as db:
140             del db[key]
141
142     def __iter__(self) -> Iterator:
143         return iter(self._read_memory())
144
145     def __reversed__(self):
146         return reversed(self._read_memory())
147
148     def __del__(self) -> None:
149         self.close()
150
151     def __contains__(self, key: str) -> bool:
152         return key in self._read_memory()
153
154     def __eq__(self, other: Any) -> bool:
155         return self._read_memory() == other
156
157     def __ne__(self, other: Any) -> bool:
158         return self._read_memory() != other
159
160     def __str__(self) -> str:
161         return str(self._read_memory())
162
163     def __repr__(self) -> str:
164         return repr(self._read_memory())
165
166     def get(self, key: str, default: Optional[Any] = None) -> Any:
167         return self._read_memory().get(key, default)
168
169     def keys(self) -> KeysView[Any]:
170         return self._read_memory().keys()
171
172     def values(self) -> ValuesView[Any]:
173         return self._read_memory().values()
174
175     def items(self) -> ItemsView:
176         return self._read_memory().items()
177
178     def pop(self, key: str, default: Optional[Any] = None):
179         with self._modify_db() as db:
180             if default is None:
181                 return db.pop(key)
182             return db.pop(key, default)
183
184     def update(self, other=(), /, **kwds):
185         with self._modify_db() as db:
186             db.update(other, **kwds)
187
188     def setdefault(self, key: str, default: Optional[Any] = None):
189         with self._modify_db() as db:
190             return db.setdefault(key, default)
191
192     def _save_memory(self, db: Dict[str, Any]) -> None:
193         with SharedDict.MPLOCK:
194             data = self._serializer.dumps(db)
195             try:
196                 self.shared_memory.buf[: len(data)] = data
197             except ValueError as exc:
198                 raise ValueError("exceeds available storage") from exc
199
200     def _read_memory(self) -> Dict[str, Any]:
201         with SharedDict.MPLOCK:
202             return self._serializer.loads(self.shared_memory.buf.tobytes())
203
204
205 if __name__ == '__main__':
206     import doctest
207
208     doctest.testmod()