Add a MP shared dict.
[python_utils.git] / collect / shared_dict.py
1 #!/usr/bin/env python3
2
3 import pickle
4 from contextlib import contextmanager
5 from functools import wraps
6 from multiprocessing import shared_memory, Lock
7 from typing import (
8     Any,
9     Dict,
10     Generator,
11     KeysView,
12     ItemsView,
13     Iterator,
14     Optional,
15     ValuesView,
16 )
17
18 from decorator_utils import synchronized
19
20
21 class PickleSerializer:
22     def dumps(self, obj: dict) -> bytes:
23         try:
24             return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
25         except pickle.PicklingError as e:
26             raise Exception(e)
27
28     def loads(self, data: bytes) -> dict:
29         try:
30             return pickle.loads(data)
31         except pickle.UnpicklingError as e:
32             raise Exception(e)
33
34
35 # TODO: protobuf serializer?
36
37
38 class SharedDict(object):
39     NULL_BYTE = b'\x00'
40     MPLOCK = Lock()
41
42     def __init__(
43         self,
44         name: str,
45         size: int,
46     ) -> None:
47         super().__init__()
48         self._serializer = PickleSerializer()
49         self.shared_memory = self._get_or_create_memory_block(name, size)
50         self._ensure_memory_initialization()
51
52     def _get_or_create_memory_block(
53         self, name: str, size: int
54     ) -> shared_memory.SharedMemory:
55         try:
56             return shared_memory.SharedMemory(name=name)
57         except FileNotFoundError:
58             return shared_memory.SharedMemory(name=name, create=True, size=size)
59
60     def _ensure_memory_initialization(self):
61         memory_is_empty = (
62             bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
63         )
64         if memory_is_empty:
65             self.clear()
66
67     def close(self) -> None:
68         if not hasattr(self, 'shared_memory'):
69             return
70         self.shared_memory.close()
71
72     def cleanup(self) -> None:
73         if not hasattr(self, 'shared_memory'):
74             return
75         self.shared_memory.unlink()
76
77     @synchronized(MPLOCK)
78     def clear(self) -> None:
79         self._save_memory({})
80
81     def popitem(self, last: Optional[bool] = None) -> Any:
82         with self._modify_db() as db:
83             return db.popitem()
84
85     @synchronized(MPLOCK)
86     @contextmanager
87     def _modify_db(self) -> Generator:
88         db = self._read_memory()
89         yield db
90         self._save_memory(db)
91
92     def __getitem__(self, key: str) -> Any:
93         return self._read_memory()[key]
94
95     def __setitem__(self, key: str, value: Any) -> None:
96         with self._modify_db() as db:
97             db[key] = value
98
99     def __len__(self) -> int:
100         return len(self._read_memory())
101
102     def __delitem__(self, key: str) -> None:
103         with self._modify_db() as db:
104             del db[key]
105
106     def __iter__(self) -> Iterator:
107         return iter(self._read_memory())
108
109     def __reversed__(self):
110         return reversed(self._read_memory())
111
112     def __del__(self) -> None:
113         self.close()
114
115     def __contains__(self, key: str) -> bool:
116         return key in self._read_memory()
117
118     def __eq__(self, other: Any) -> bool:
119         return self._read_memory() == other
120
121     def __ne__(self, other: Any) -> bool:
122         return self._read_memory() != other
123
124     def __str__(self) -> str:
125         return str(self._read_memory())
126
127     def __repr__(self) -> str:
128         return repr(self._read_memory())
129
130     def get(self, key: str, default: Optional[Any] = None) -> Any:
131         return self._read_memory().get(key, default)
132
133     def keys(self) -> KeysView[Any]:
134         return self._read_memory().keys()
135
136     def values(self) -> ValuesView[Any]:
137         return self._read_memory().values()
138
139     def items(self) -> ItemsView:
140         return self._read_memory().items()
141
142     def pop(self, key: str, default: Optional[Any] = None):
143         with self._modify_db() as db:
144             if default is None:
145                 return db.pop(key)
146             return db.pop(key, default)
147
148     def update(self, other=(), /, **kwds):
149         with self._modify_db() as db:
150             db.update(other, **kwds)
151
152     def setdefault(self, key: str, default: Optional[Any] = None):
153         with self._modify_db() as db:
154             return db.setdefault(key, default)
155
156     def _save_memory(self, db: Dict[str, Any]) -> None:
157         data = self._serializer.dumps(db)
158         try:
159             self.shared_memory.buf[: len(data)] = data
160         except ValueError as exc:
161             raise ValueError("exceeds available storage") from exc
162
163     def _read_memory(self) -> Dict[str, Any]:
164         return self._serializer.loads(self.shared_memory.buf.tobytes())
165
166
167 if __name__ == '__main__':
168     import doctest
169
170     doctest.testmod()