Change locking boundaries for shared dict. Add a unit test.
[python_utils.git] / collect / shared_dict.py
1 #!/usr/bin/env python3
2
3 """
4 The MIT License (MIT)
5
6 Copyright (c) 2020 LuizaLabs
7 Additions Copyright (c) 2022 Scott Gasch
8
9 Permission is hereby granted, free of charge, to any person obtaining a copy
10 of this software and associated documentation files (the "Software"), to deal
11 in the Software without restriction, including without limitation the rights
12 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 copies of the Software, and to permit persons to whom the Software is
14 furnished to do so, subject to the following conditions:
15
16 The above copyright notice and this permission notice shall be included in all
17 copies or substantial portions of the Software.
18
19 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 SOFTWARE.
26
27 This class is based on https://github.com/luizalabs/shared-memory-dict
28 """
29
30 import pickle
31 from contextlib import contextmanager
32 from functools import wraps
33 from multiprocessing import shared_memory, RLock
34 from typing import (
35     Any,
36     Dict,
37     Generator,
38     KeysView,
39     ItemsView,
40     Iterator,
41     Optional,
42     ValuesView,
43 )
44
45 from decorator_utils import synchronized
46
47
48 class PickleSerializer:
49     def dumps(self, obj: dict) -> bytes:
50         try:
51             return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
52         except pickle.PicklingError as e:
53             raise Exception(e)
54
55     def loads(self, data: bytes) -> dict:
56         try:
57             return pickle.loads(data)
58         except pickle.UnpicklingError as e:
59             raise Exception(e)
60
61
62 # TODO: protobuf serializer?
63
64
65 class SharedDict(object):
66     NULL_BYTE = b'\x00'
67     MPLOCK = RLock()
68
69     def __init__(
70         self,
71         name: str,
72         size_bytes: Optional[int] = None,
73     ) -> None:
74         super().__init__()
75         self.name = name
76         self._serializer = PickleSerializer()
77         self.shared_memory = self._get_or_create_memory_block(name, size_bytes)
78         self._ensure_memory_initialization()
79         self.lock = RLock()
80
81     def get_name(self):
82         return self.name
83
84     def _get_or_create_memory_block(
85         self,
86         name: str,
87         size_bytes: Optional[int] = None,
88     ) -> shared_memory.SharedMemory:
89         try:
90             return shared_memory.SharedMemory(name=name)
91         except FileNotFoundError:
92             assert size_bytes
93             return shared_memory.SharedMemory(name=name, create=True, size=size_bytes)
94
95     def _ensure_memory_initialization(self):
96         memory_is_empty = (
97             bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
98         )
99         if memory_is_empty:
100             self.clear()
101
102     def close(self) -> None:
103         if not hasattr(self, 'shared_memory'):
104             return
105         self.shared_memory.close()
106
107     def cleanup(self) -> None:
108         if not hasattr(self, 'shared_memory'):
109             return
110         self.shared_memory.unlink()
111
112     def clear(self) -> None:
113         self._save_memory({})
114
115     def popitem(self, last: Optional[bool] = None) -> Any:
116         with self._modify_db() as db:
117             return db.popitem()
118
119     @contextmanager
120     def _modify_db(self) -> Generator:
121         with SharedDict.MPLOCK:
122             db = self._read_memory()
123             yield db
124             self._save_memory(db)
125
126     def __getitem__(self, key: str) -> Any:
127         return self._read_memory()[key]
128
129     def __setitem__(self, key: str, value: Any) -> None:
130         with self._modify_db() as db:
131             db[key] = value
132
133     def __len__(self) -> int:
134         return len(self._read_memory())
135
136     def __delitem__(self, key: str) -> None:
137         with self._modify_db() as db:
138             del db[key]
139
140     def __iter__(self) -> Iterator:
141         return iter(self._read_memory())
142
143     def __reversed__(self):
144         return reversed(self._read_memory())
145
146     def __del__(self) -> None:
147         self.close()
148
149     def __contains__(self, key: str) -> bool:
150         return key in self._read_memory()
151
152     def __eq__(self, other: Any) -> bool:
153         return self._read_memory() == other
154
155     def __ne__(self, other: Any) -> bool:
156         return self._read_memory() != other
157
158     def __str__(self) -> str:
159         return str(self._read_memory())
160
161     def __repr__(self) -> str:
162         return repr(self._read_memory())
163
164     def get(self, key: str, default: Optional[Any] = None) -> Any:
165         return self._read_memory().get(key, default)
166
167     def keys(self) -> KeysView[Any]:
168         return self._read_memory().keys()
169
170     def values(self) -> ValuesView[Any]:
171         return self._read_memory().values()
172
173     def items(self) -> ItemsView:
174         return self._read_memory().items()
175
176     def pop(self, key: str, default: Optional[Any] = None):
177         with self._modify_db() as db:
178             if default is None:
179                 return db.pop(key)
180             return db.pop(key, default)
181
182     def update(self, other=(), /, **kwds):
183         with self._modify_db() as db:
184             db.update(other, **kwds)
185
186     def setdefault(self, key: str, default: Optional[Any] = None):
187         with self._modify_db() as db:
188             return db.setdefault(key, default)
189
190     def _save_memory(self, db: Dict[str, Any]) -> None:
191         with SharedDict.MPLOCK:
192             data = self._serializer.dumps(db)
193             try:
194                 self.shared_memory.buf[: len(data)] = data
195             except ValueError as exc:
196                 raise ValueError("exceeds available storage") from exc
197
198     def _read_memory(self) -> Dict[str, Any]:
199         with SharedDict.MPLOCK:
200             return self._serializer.loads(self.shared_memory.buf.tobytes())
201
202
203 if __name__ == '__main__':
204     import doctest
205
206     doctest.testmod()