Let's be explicit with asserts; there was a bug in histogram
[python_utils.git] / collect / shared_dict.py
1 #!/usr/bin/env python3
2
3 """
4 The MIT License (MIT)
5
6 Copyright (c) 2020 LuizaLabs
7 Additions Copyright (c) 2022 Scott Gasch
8
9 Permission is hereby granted, free of charge, to any person obtaining a copy
10 of this software and associated documentation files (the "Software"), to deal
11 in the Software without restriction, including without limitation the rights
12 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 copies of the Software, and to permit persons to whom the Software is
14 furnished to do so, subject to the following conditions:
15
16 The above copyright notice and this permission notice shall be included in all
17 copies or substantial portions of the Software.
18
19 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 SOFTWARE.
26
27 This class is based on https://github.com/luizalabs/shared-memory-dict
28 """
29
30 import pickle
31 from contextlib import contextmanager
32 from functools import wraps
33 from multiprocessing import RLock, shared_memory
34 from typing import (
35     Any,
36     Dict,
37     Generator,
38     ItemsView,
39     Iterator,
40     KeysView,
41     Optional,
42     ValuesView,
43 )
44
45 from decorator_utils import synchronized
46
47
48 class PickleSerializer:
49     def dumps(self, obj: dict) -> bytes:
50         try:
51             return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
52         except pickle.PicklingError as e:
53             raise Exception(e)
54
55     def loads(self, data: bytes) -> dict:
56         try:
57             return pickle.loads(data)
58         except pickle.UnpicklingError as e:
59             raise Exception(e)
60
61
62 # TODO: protobuf serializer?
63
64
65 class SharedDict(object):
66     NULL_BYTE = b'\x00'
67     MPLOCK = RLock()
68
69     def __init__(
70         self,
71         name: str,
72         size_bytes: Optional[int] = None,
73     ) -> None:
74         super().__init__()
75         self.name = name
76         self._serializer = PickleSerializer()
77         assert size_bytes is None or size_bytes > 0
78         self.shared_memory = self._get_or_create_memory_block(name, size_bytes)
79         self._ensure_memory_initialization()
80         self.lock = RLock()
81
82     def get_name(self):
83         return self.name
84
85     def _get_or_create_memory_block(
86         self,
87         name: str,
88         size_bytes: Optional[int] = None,
89     ) -> shared_memory.SharedMemory:
90         try:
91             return shared_memory.SharedMemory(name=name)
92         except FileNotFoundError:
93             assert size_bytes is not None
94             return shared_memory.SharedMemory(name=name, create=True, size=size_bytes)
95
96     def _ensure_memory_initialization(self):
97         memory_is_empty = (
98             bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
99         )
100         if memory_is_empty:
101             self.clear()
102
103     def close(self) -> None:
104         if not hasattr(self, 'shared_memory'):
105             return
106         self.shared_memory.close()
107
108     def cleanup(self) -> None:
109         if not hasattr(self, 'shared_memory'):
110             return
111         self.shared_memory.unlink()
112
113     def clear(self) -> None:
114         self._save_memory({})
115
116     def popitem(self, last: Optional[bool] = None) -> Any:
117         with self._modify_db() as db:
118             return db.popitem()
119
120     @contextmanager
121     def _modify_db(self) -> Generator:
122         with SharedDict.MPLOCK:
123             db = self._read_memory()
124             yield db
125             self._save_memory(db)
126
127     def __getitem__(self, key: str) -> Any:
128         return self._read_memory()[key]
129
130     def __setitem__(self, key: str, value: Any) -> None:
131         with self._modify_db() as db:
132             db[key] = value
133
134     def __len__(self) -> int:
135         return len(self._read_memory())
136
137     def __delitem__(self, key: str) -> None:
138         with self._modify_db() as db:
139             del db[key]
140
141     def __iter__(self) -> Iterator:
142         return iter(self._read_memory())
143
144     def __reversed__(self):
145         return reversed(self._read_memory())
146
147     def __del__(self) -> None:
148         self.close()
149
150     def __contains__(self, key: str) -> bool:
151         return key in self._read_memory()
152
153     def __eq__(self, other: Any) -> bool:
154         return self._read_memory() == other
155
156     def __ne__(self, other: Any) -> bool:
157         return self._read_memory() != other
158
159     def __str__(self) -> str:
160         return str(self._read_memory())
161
162     def __repr__(self) -> str:
163         return repr(self._read_memory())
164
165     def get(self, key: str, default: Optional[Any] = None) -> Any:
166         return self._read_memory().get(key, default)
167
168     def keys(self) -> KeysView[Any]:
169         return self._read_memory().keys()
170
171     def values(self) -> ValuesView[Any]:
172         return self._read_memory().values()
173
174     def items(self) -> ItemsView:
175         return self._read_memory().items()
176
177     def pop(self, key: str, default: Optional[Any] = None):
178         with self._modify_db() as db:
179             if default is None:
180                 return db.pop(key)
181             return db.pop(key, default)
182
183     def update(self, other=(), /, **kwds):
184         with self._modify_db() as db:
185             db.update(other, **kwds)
186
187     def setdefault(self, key: str, default: Optional[Any] = None):
188         with self._modify_db() as db:
189             return db.setdefault(key, default)
190
191     def _save_memory(self, db: Dict[str, Any]) -> None:
192         with SharedDict.MPLOCK:
193             data = self._serializer.dumps(db)
194             try:
195                 self.shared_memory.buf[: len(data)] = data
196             except ValueError as exc:
197                 raise ValueError("exceeds available storage") from exc
198
199     def _read_memory(self) -> Dict[str, Any]:
200         with SharedDict.MPLOCK:
201             return self._serializer.loads(self.shared_memory.buf.tobytes())
202
203
204 if __name__ == '__main__':
205     import doctest
206
207     doctest.testmod()