Change settings in flake8 and black.
[python_utils.git] / collect / shared_dict.py
1 #!/usr/bin/env python3
2
3 """
4 The MIT License (MIT)
5
6 Copyright (c) 2020 LuizaLabs
7 Additions Copyright (c) 2022 Scott Gasch
8
9 Permission is hereby granted, free of charge, to any person obtaining a copy
10 of this software and associated documentation files (the "Software"), to deal
11 in the Software without restriction, including without limitation the rights
12 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 copies of the Software, and to permit persons to whom the Software is
14 furnished to do so, subject to the following conditions:
15
16 The above copyright notice and this permission notice shall be included in all
17 copies or substantial portions of the Software.
18
19 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 SOFTWARE.
26
27 This class is based on https://github.com/luizalabs/shared-memory-dict
28 """
29
30 import pickle
31 from contextlib import contextmanager
32 from functools import wraps
33 from multiprocessing import RLock, shared_memory
34 from typing import Any, Dict, Generator, ItemsView, Iterator, KeysView, Optional, ValuesView
35
36 from decorator_utils import synchronized
37
38
39 class PickleSerializer:
40     def dumps(self, obj: dict) -> bytes:
41         try:
42             return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
43         except pickle.PicklingError as e:
44             raise Exception(e)
45
46     def loads(self, data: bytes) -> dict:
47         try:
48             return pickle.loads(data)
49         except pickle.UnpicklingError as e:
50             raise Exception(e)
51
52
53 # TODO: protobuf serializer?
54
55
56 class SharedDict(object):
57     NULL_BYTE = b'\x00'
58     MPLOCK = RLock()
59
60     def __init__(
61         self,
62         name: str,
63         size_bytes: Optional[int] = None,
64     ) -> None:
65         super().__init__()
66         self.name = name
67         self._serializer = PickleSerializer()
68         assert size_bytes is None or size_bytes > 0
69         self.shared_memory = self._get_or_create_memory_block(name, size_bytes)
70         self._ensure_memory_initialization()
71         self.lock = RLock()
72
73     def get_name(self):
74         return self.name
75
76     def _get_or_create_memory_block(
77         self,
78         name: str,
79         size_bytes: Optional[int] = None,
80     ) -> shared_memory.SharedMemory:
81         try:
82             return shared_memory.SharedMemory(name=name)
83         except FileNotFoundError:
84             assert size_bytes is not None
85             return shared_memory.SharedMemory(name=name, create=True, size=size_bytes)
86
87     def _ensure_memory_initialization(self):
88         memory_is_empty = bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
89         if memory_is_empty:
90             self.clear()
91
92     def close(self) -> None:
93         if not hasattr(self, 'shared_memory'):
94             return
95         self.shared_memory.close()
96
97     def cleanup(self) -> None:
98         if not hasattr(self, 'shared_memory'):
99             return
100         with SharedDict.MPLOCK:
101             self.shared_memory.unlink()
102
103     def clear(self) -> None:
104         self._save_memory({})
105
106     def popitem(self, last: Optional[bool] = None) -> Any:
107         with self._modify_db() as db:
108             return db.popitem()
109
110     @contextmanager
111     def _modify_db(self) -> Generator:
112         with SharedDict.MPLOCK:
113             db = self._read_memory()
114             yield db
115             self._save_memory(db)
116
117     def __getitem__(self, key: str) -> Any:
118         return self._read_memory()[key]
119
120     def __setitem__(self, key: str, value: Any) -> None:
121         with self._modify_db() as db:
122             db[key] = value
123
124     def __len__(self) -> int:
125         return len(self._read_memory())
126
127     def __delitem__(self, key: str) -> None:
128         with self._modify_db() as db:
129             del db[key]
130
131     def __iter__(self) -> Iterator:
132         return iter(self._read_memory())
133
134     def __reversed__(self):
135         return reversed(self._read_memory())
136
137     def __del__(self) -> None:
138         self.close()
139
140     def __contains__(self, key: str) -> bool:
141         return key in self._read_memory()
142
143     def __eq__(self, other: Any) -> bool:
144         return self._read_memory() == other
145
146     def __ne__(self, other: Any) -> bool:
147         return self._read_memory() != other
148
149     def __str__(self) -> str:
150         return str(self._read_memory())
151
152     def __repr__(self) -> str:
153         return repr(self._read_memory())
154
155     def get(self, key: str, default: Optional[Any] = None) -> Any:
156         return self._read_memory().get(key, default)
157
158     def keys(self) -> KeysView[Any]:
159         return self._read_memory().keys()
160
161     def values(self) -> ValuesView[Any]:
162         return self._read_memory().values()
163
164     def items(self) -> ItemsView:
165         return self._read_memory().items()
166
167     def pop(self, key: str, default: Optional[Any] = None):
168         with self._modify_db() as db:
169             if default is None:
170                 return db.pop(key)
171             return db.pop(key, default)
172
173     def update(self, other=(), /, **kwds):
174         with self._modify_db() as db:
175             db.update(other, **kwds)
176
177     def setdefault(self, key: str, default: Optional[Any] = None):
178         with self._modify_db() as db:
179             return db.setdefault(key, default)
180
181     def _save_memory(self, db: Dict[str, Any]) -> None:
182         with SharedDict.MPLOCK:
183             data = self._serializer.dumps(db)
184             try:
185                 self.shared_memory.buf[: len(data)] = data
186             except ValueError as exc:
187                 raise ValueError("exceeds available storage") from exc
188
189     def _read_memory(self) -> Dict[str, Any]:
190         with SharedDict.MPLOCK:
191             return self._serializer.loads(self.shared_memory.buf.tobytes())
192
193
194 if __name__ == '__main__':
195     import doctest
196
197     doctest.testmod()