Add MIT License and copyright from the shared dict implementation
[python_utils.git] / collect / shared_dict.py
1 #!/usr/bin/env python3
2
3 """
4 The MIT License (MIT)
5
6 Copyright (c) 2020 LuizaLabs
7 Additions Copyright (c) 2022 Scott Gasch
8
9 Permission is hereby granted, free of charge, to any person obtaining a copy
10 of this software and associated documentation files (the "Software"), to deal
11 in the Software without restriction, including without limitation the rights
12 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 copies of the Software, and to permit persons to whom the Software is
14 furnished to do so, subject to the following conditions:
15
16 The above copyright notice and this permission notice shall be included in all
17 copies or substantial portions of the Software.
18
19 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 SOFTWARE.
26
27 This class is based on https://github.com/luizalabs/shared-memory-dict
28 """
29
30 import pickle
31 from contextlib import contextmanager
32 from functools import wraps
33 from multiprocessing import shared_memory, Lock
34 from typing import (
35     Any,
36     Dict,
37     Generator,
38     KeysView,
39     ItemsView,
40     Iterator,
41     Optional,
42     ValuesView,
43 )
44
45 from decorator_utils import synchronized
46
47
48 class PickleSerializer:
49     def dumps(self, obj: dict) -> bytes:
50         try:
51             return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
52         except pickle.PicklingError as e:
53             raise Exception(e)
54
55     def loads(self, data: bytes) -> dict:
56         try:
57             return pickle.loads(data)
58         except pickle.UnpicklingError as e:
59             raise Exception(e)
60
61
62 # TODO: protobuf serializer?
63
64
65 class SharedDict(object):
66     NULL_BYTE = b'\x00'
67     MPLOCK = Lock()
68
69     def __init__(
70         self,
71         name: str,
72         size: int,
73     ) -> None:
74         super().__init__()
75         self._serializer = PickleSerializer()
76         self.shared_memory = self._get_or_create_memory_block(name, size)
77         self._ensure_memory_initialization()
78
79     def _get_or_create_memory_block(
80         self, name: str, size: int
81     ) -> shared_memory.SharedMemory:
82         try:
83             return shared_memory.SharedMemory(name=name)
84         except FileNotFoundError:
85             return shared_memory.SharedMemory(name=name, create=True, size=size)
86
87     def _ensure_memory_initialization(self):
88         memory_is_empty = (
89             bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
90         )
91         if memory_is_empty:
92             self.clear()
93
94     def close(self) -> None:
95         if not hasattr(self, 'shared_memory'):
96             return
97         self.shared_memory.close()
98
99     def cleanup(self) -> None:
100         if not hasattr(self, 'shared_memory'):
101             return
102         self.shared_memory.unlink()
103
104     @synchronized(MPLOCK)
105     def clear(self) -> None:
106         self._save_memory({})
107
108     def popitem(self, last: Optional[bool] = None) -> Any:
109         with self._modify_db() as db:
110             return db.popitem()
111
112     @synchronized(MPLOCK)
113     @contextmanager
114     def _modify_db(self) -> Generator:
115         db = self._read_memory()
116         yield db
117         self._save_memory(db)
118
119     def __getitem__(self, key: str) -> Any:
120         return self._read_memory()[key]
121
122     def __setitem__(self, key: str, value: Any) -> None:
123         with self._modify_db() as db:
124             db[key] = value
125
126     def __len__(self) -> int:
127         return len(self._read_memory())
128
129     def __delitem__(self, key: str) -> None:
130         with self._modify_db() as db:
131             del db[key]
132
133     def __iter__(self) -> Iterator:
134         return iter(self._read_memory())
135
136     def __reversed__(self):
137         return reversed(self._read_memory())
138
139     def __del__(self) -> None:
140         self.close()
141
142     def __contains__(self, key: str) -> bool:
143         return key in self._read_memory()
144
145     def __eq__(self, other: Any) -> bool:
146         return self._read_memory() == other
147
148     def __ne__(self, other: Any) -> bool:
149         return self._read_memory() != other
150
151     def __str__(self) -> str:
152         return str(self._read_memory())
153
154     def __repr__(self) -> str:
155         return repr(self._read_memory())
156
157     def get(self, key: str, default: Optional[Any] = None) -> Any:
158         return self._read_memory().get(key, default)
159
160     def keys(self) -> KeysView[Any]:
161         return self._read_memory().keys()
162
163     def values(self) -> ValuesView[Any]:
164         return self._read_memory().values()
165
166     def items(self) -> ItemsView:
167         return self._read_memory().items()
168
169     def pop(self, key: str, default: Optional[Any] = None):
170         with self._modify_db() as db:
171             if default is None:
172                 return db.pop(key)
173             return db.pop(key, default)
174
175     def update(self, other=(), /, **kwds):
176         with self._modify_db() as db:
177             db.update(other, **kwds)
178
179     def setdefault(self, key: str, default: Optional[Any] = None):
180         with self._modify_db() as db:
181             return db.setdefault(key, default)
182
183     def _save_memory(self, db: Dict[str, Any]) -> None:
184         data = self._serializer.dumps(db)
185         try:
186             self.shared_memory.buf[: len(data)] = data
187         except ValueError as exc:
188             raise ValueError("exceeds available storage") from exc
189
190     def _read_memory(self) -> Dict[str, Any]:
191         return self._serializer.loads(self.shared_memory.buf.tobytes())
192
193
194 if __name__ == '__main__':
195     import doctest
196
197     doctest.testmod()