4 from contextlib import contextmanager
5 from functools import wraps
6 from multiprocessing import shared_memory, Lock
18 from decorator_utils import synchronized
# Serializer that converts dicts to and from bytes with the pickle module.
# NOTE(review): this extract omits some lines of the class (the `try:` lines
# and the bodies of both `except` clauses), so the error handling cannot be
# documented from here -- confirm against the full file.
21 class PickleSerializer:
# Serialize `obj` with pickle.dumps using pickle.HIGHEST_PROTOCOL.
22 def dumps(self, obj: dict) -> bytes:
24 return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
25 except pickle.PicklingError as e:
# Deserialize `data` with pickle.loads.
# NOTE(review): pickle.loads can execute arbitrary code while unpickling;
# only feed it bytes written by trusted processes.
28 def loads(self, data: bytes) -> dict:
30 return pickle.loads(data)
31 except pickle.UnpicklingError as e:
35 # TODO: protobuf serializer?
# Dict-like object whose contents are serialized into a named
# multiprocessing shared memory block.
# NOTE(review): the lines between the class header and the statements below
# (presumably the `__init__` signature that supplies `name` and `size`) are
# not visible in this extract -- confirm against the full file.
38 class SharedDict(object):
# Serializer used to encode and decode the dict kept in shared memory.
48 self._serializer = PickleSerializer()
# Attach to (or create) the block for `name`/`size`, then run the
# initialization check on it.
49 self.shared_memory = self._get_or_create_memory_block(name, size)
50 self._ensure_memory_initialization()
def _get_or_create_memory_block(
    self, name: str, size: int
) -> shared_memory.SharedMemory:
    """Attach to the shared memory block called *name*, creating it if absent.

    Args:
        name: Name of the shared memory block.
        size: Size in bytes requested when the block has to be created; it
            is not used when an existing block is attached.

    Returns:
        The attached or newly created ``SharedMemory`` object.
    """
    try:
        return shared_memory.SharedMemory(name=name)
    except FileNotFoundError:
        # No block with this name exists yet, so create it.
        # NOTE(review): another process could create the block between the
        # failed attach and this call, which would raise FileExistsError.
        return shared_memory.SharedMemory(name=name, create=True, size=size)
# NOTE(review): only part of this method is visible in this extract. The
# visible expression copies the buffer to bytes and tests whether everything
# before the first SharedDict.NULL_BYTE is empty; the statement that uses the
# result is not shown -- confirm against the full file.
60 def _ensure_memory_initialization(self):
62 bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
# Call close() on the underlying SharedMemory object.
# NOTE(review): the statement guarded by the hasattr check is not visible in
# this extract -- presumably an early `return` for instances whose
# `shared_memory` attribute was never set; confirm against the full file.
67 def close(self) -> None:
68 if not hasattr(self, 'shared_memory'):
70 self.shared_memory.close()
# Call unlink() on the underlying SharedMemory object, which requests
# destruction of the shared memory block.
# NOTE(review): the statement guarded by the hasattr check is not visible in
# this extract -- presumably an early `return`; confirm against the full file.
72 def cleanup(self) -> None:
73 if not hasattr(self, 'shared_memory'):
75 self.shared_memory.unlink()
# NOTE(review): the body (and any decorator) of this method is not visible in
# this extract; by its name it presumably empties the stored mapping --
# confirm against the full file.
78 def clear(self) -> None:
# Operates on the mapping obtained from the _modify_db() context manager.
# NOTE(review): the statements inside the `with` block are not visible in this
# extract, so how `last` is used and what is returned cannot be documented
# from here -- confirm against the full file.
81 def popitem(self, last: Optional[bool] = None) -> Any:
82 with self._modify_db() as db:
# Helper used by the mutating methods as `with self._modify_db() as db:`.
# It starts by loading the current mapping with _read_memory().
# NOTE(review): any decorators and the rest of the body are not visible in
# this extract -- presumably it yields `db` and then writes it back; confirm
# against the full file.
87 def _modify_db(self) -> Generator:
88 db = self._read_memory()
92 def __getitem__(self, key: str) -> Any:
93 return self._read_memory()[key]
# Item assignment; operates on the mapping obtained from _modify_db().
# NOTE(review): the statement inside the `with` block is not visible in this
# extract -- presumably it stores `value` under `key`; confirm against the
# full file.
95 def __setitem__(self, key: str, value: Any) -> None:
96 with self._modify_db() as db:
99 def __len__(self) -> int:
100 return len(self._read_memory())
# Item deletion; operates on the mapping obtained from _modify_db().
# NOTE(review): the statement inside the `with` block is not visible in this
# extract -- presumably it deletes `key`; confirm against the full file.
102 def __delitem__(self, key: str) -> None:
103 with self._modify_db() as db:
106 def __iter__(self) -> Iterator:
107 return iter(self._read_memory())
109 def __reversed__(self):
110 return reversed(self._read_memory())
# Finalizer hook.
# NOTE(review): the body is not visible in this extract -- presumably it
# releases the shared memory handle; confirm against the full file.
112 def __del__(self) -> None:
115 def __contains__(self, key: str) -> bool:
116 return key in self._read_memory()
118 def __eq__(self, other: Any) -> bool:
119 return self._read_memory() == other
121 def __ne__(self, other: Any) -> bool:
122 return self._read_memory() != other
124 def __str__(self) -> str:
125 return str(self._read_memory())
127 def __repr__(self) -> str:
128 return repr(self._read_memory())
def get(self, key: str, default: Optional[Any] = None) -> Any:
    """Return the value stored under *key*, or *default* if it is absent.

    Args:
        key: Key to look up in the mapping read by ``_read_memory()``.
        default: Value returned when *key* is not present.
    """
    return self._read_memory().get(key, default)
def keys(self) -> KeysView[Any]:
    """Return the keys view of the dict obtained from one ``_read_memory()`` call."""
    return self._read_memory().keys()
def values(self) -> ValuesView[Any]:
    """Return the values view of the dict obtained from one ``_read_memory()`` call."""
    return self._read_memory().values()
def items(self) -> ItemsView:
    """Return the items view of the dict obtained from one ``_read_memory()`` call."""
    return self._read_memory().items()
# Remove `key` from the mapping obtained from _modify_db(); the visible final
# statement returns db.pop(key, default).
# NOTE(review): two lines between the `with` and that `return` are not
# visible in this extract, so any special handling of `default` or of a
# missing key cannot be documented from here -- confirm against the full file.
142 def pop(self, key: str, default: Optional[Any] = None):
143 with self._modify_db() as db:
146 return db.pop(key, default)
def update(self, other=(), /, **kwds):
    """Merge *other* and the keyword pairs into the mapping.

    The merge is a plain ``dict.update`` on the mapping handed out by the
    ``_modify_db()`` context manager.

    Args:
        other: Mapping or iterable of key/value pairs (positional-only).
        **kwds: Additional key/value pairs to merge.
    """
    with self._modify_db() as db:
        db.update(other, **kwds)
def setdefault(self, key: str, default: Optional[Any] = None):
    """Return the value for *key*, inserting *default* first if it is absent.

    The operation is a plain ``dict.setdefault`` on the mapping handed out by
    the ``_modify_db()`` context manager.

    Args:
        key: Key to look up.
        default: Value stored and returned when *key* is not present.
    """
    with self._modify_db() as db:
        return db.setdefault(key, default)
156 def _save_memory(self, db: Dict[str, Any]) -> None:
157 data = self._serializer.dumps(db)
159 self.shared_memory.buf[: len(data)] = data
160 except ValueError as exc:
161 raise ValueError("exceeds available storage") from exc
163 def _read_memory(self) -> Dict[str, Any]:
164 return self._serializer.loads(self.shared_memory.buf.tobytes())
167 if __name__ == '__main__':