Reduce the doctest lease duration...
[python_utils.git] / collect / shared_dict.py
1 #!/usr/bin/env python3
2
3 """
4 The MIT License (MIT)
5
6 Copyright (c) 2020 LuizaLabs
7 Additions/Modifications Copyright (c) 2022 Scott Gasch
8
9 Permission is hereby granted, free of charge, to any person obtaining a copy
10 of this software and associated documentation files (the "Software"), to deal
11 in the Software without restriction, including without limitation the rights
12 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 copies of the Software, and to permit persons to whom the Software is
14 furnished to do so, subject to the following conditions:
15
16 The above copyright notice and this permission notice shall be included in all
17 copies or substantial portions of the Software.
18
19 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 SOFTWARE.
26
27 This class is based on https://github.com/luizalabs/shared-memory-dict
28 """
29
30 import pickle
31 from contextlib import contextmanager
32 from multiprocessing import RLock, shared_memory
33 from typing import (
34     Any,
35     Dict,
36     Generator,
37     ItemsView,
38     Iterator,
39     KeysView,
40     Optional,
41     ValuesView,
42 )
43
44
45 class PickleSerializer:
46     def dumps(self, obj: dict) -> bytes:
47         try:
48             return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
49         except pickle.PicklingError as e:
50             raise Exception(e)
51
52     def loads(self, data: bytes) -> dict:
53         try:
54             return pickle.loads(data)
55         except pickle.UnpicklingError as e:
56             raise Exception(e)
57
58
59 # TODO: protobuf serializer?
60
61
62 class SharedDict(object):
63     NULL_BYTE = b'\x00'
64     MPLOCK = RLock()
65
66     def __init__(
67         self,
68         name: str,
69         size_bytes: Optional[int] = None,
70     ) -> None:
71         super().__init__()
72         self.name = name
73         self._serializer = PickleSerializer()
74         assert size_bytes is None or size_bytes > 0
75         self.shared_memory = self._get_or_create_memory_block(name, size_bytes)
76         self._ensure_memory_initialization()
77         self.lock = RLock()
78
79     def get_name(self):
80         return self.name
81
82     def _get_or_create_memory_block(
83         self,
84         name: str,
85         size_bytes: Optional[int] = None,
86     ) -> shared_memory.SharedMemory:
87         try:
88             return shared_memory.SharedMemory(name=name)
89         except FileNotFoundError:
90             assert size_bytes is not None
91             return shared_memory.SharedMemory(name=name, create=True, size=size_bytes)
92
93     def _ensure_memory_initialization(self):
94         memory_is_empty = bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
95         if memory_is_empty:
96             self.clear()
97
98     def close(self) -> None:
99         if not hasattr(self, 'shared_memory'):
100             return
101         self.shared_memory.close()
102
103     def cleanup(self) -> None:
104         if not hasattr(self, 'shared_memory'):
105             return
106         with SharedDict.MPLOCK:
107             self.shared_memory.unlink()
108
109     def clear(self) -> None:
110         self._save_memory({})
111
112     def popitem(self, last: Optional[bool] = None) -> Any:
113         with self._modify_db() as db:
114             return db.popitem()
115
116     @contextmanager
117     def _modify_db(self) -> Generator:
118         with SharedDict.MPLOCK:
119             db = self._read_memory()
120             yield db
121             self._save_memory(db)
122
123     def __getitem__(self, key: str) -> Any:
124         return self._read_memory()[key]
125
126     def __setitem__(self, key: str, value: Any) -> None:
127         with self._modify_db() as db:
128             db[key] = value
129
130     def __len__(self) -> int:
131         return len(self._read_memory())
132
133     def __delitem__(self, key: str) -> None:
134         with self._modify_db() as db:
135             del db[key]
136
137     def __iter__(self) -> Iterator:
138         return iter(self._read_memory())
139
140     def __reversed__(self):
141         return reversed(self._read_memory())
142
143     def __del__(self) -> None:
144         self.close()
145
146     def __contains__(self, key: str) -> bool:
147         return key in self._read_memory()
148
149     def __eq__(self, other: Any) -> bool:
150         return self._read_memory() == other
151
152     def __ne__(self, other: Any) -> bool:
153         return self._read_memory() != other
154
155     def __str__(self) -> str:
156         return str(self._read_memory())
157
158     def __repr__(self) -> str:
159         return repr(self._read_memory())
160
161     def get(self, key: str, default: Optional[Any] = None) -> Any:
162         return self._read_memory().get(key, default)
163
164     def keys(self) -> KeysView[Any]:
165         return self._read_memory().keys()
166
167     def values(self) -> ValuesView[Any]:
168         return self._read_memory().values()
169
170     def items(self) -> ItemsView:
171         return self._read_memory().items()
172
173     def pop(self, key: str, default: Optional[Any] = None):
174         with self._modify_db() as db:
175             if default is None:
176                 return db.pop(key)
177             return db.pop(key, default)
178
179     def update(self, other=(), /, **kwds):
180         with self._modify_db() as db:
181             db.update(other, **kwds)
182
183     def setdefault(self, key: str, default: Optional[Any] = None):
184         with self._modify_db() as db:
185             return db.setdefault(key, default)
186
187     def _save_memory(self, db: Dict[str, Any]) -> None:
188         with SharedDict.MPLOCK:
189             data = self._serializer.dumps(db)
190             try:
191                 self.shared_memory.buf[: len(data)] = data
192             except ValueError as exc:
193                 raise ValueError("exceeds available storage") from exc
194
195     def _read_memory(self) -> Dict[str, Any]:
196         with SharedDict.MPLOCK:
197             return self._serializer.loads(self.shared_memory.buf.tobytes())
198
199
200 if __name__ == '__main__':
201     import doctest
202
203     doctest.testmod()