import pickle
from contextlib import contextmanager
from multiprocessing import RLock, shared_memory
from typing import (
Any,
Dict,
- Generator,
+ Hashable,
ItemsView,
Iterator,
KeysView,
Optional,
+ Tuple,
ValuesView,
)
class PickleSerializer:
- def dumps(self, obj: dict) -> bytes:
+ """A serializer that uses pickling. Used to read/write bytes in the shared
+ memory region and interpret them as a dict."""
+
+ def dumps(self, obj: Dict[Hashable, Any]) -> bytes:
try:
return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
except pickle.PicklingError as e:
- raise Exception(e)
+            raise Exception(e) from e
- def loads(self, data: bytes) -> dict:
+ def loads(self, data: bytes) -> Dict[Hashable, Any]:
try:
return pickle.loads(data)
except pickle.UnpicklingError as e:
- raise Exception(e)
+            raise Exception(e) from e
-# TODO: protobuf serializer?
+# TODOs: profile the serializers and figure out the fastest one. Can
+# we use a ChainMap to avoid the constant de/re-serialization of the
+# whole thing?
class SharedDict(object):
+    """This class emulates the dict container but uses a
+    multiprocessing.shared_memory.SharedMemory region to back the dict
+    so that it can be read and written by multiple independent
+    processes at the same time. Because every operation deserializes
+    and reserializes the entire dict, it is much slower than a normal
+    dict.
+
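+    A minimal single-process sketch (the region name and size below are
+    illustrative; any unique name and sufficiently large size work):
+
+    >>> d = SharedDict('shared_dict_demo', 4096)
+    >>> d['a'] = 1
+    >>> d['a']
+    1
+    >>> 'a' in d
+    True
+    >>> d2 = SharedDict('shared_dict_demo')   # attach to the same region
+    >>> d2['a']
+    1
+    >>> d2.close()
+    >>> d.close()
+    >>> d.cleanup()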
+ """
+
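+    # SharedMemory regions are zero-filled on creation, so a leading null
+    # byte means nothing has ever been written to the region.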
NULL_BYTE = b'\x00'
- MPLOCK = RLock()
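+    # Held around every read/write of the shared region. As a class
+    # attribute it is inherited by fork()ed children (the tests below
+    # assert this); under the spawn start method each process would get
+    # an unrelated copy and mutual exclusion would be lost.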
+ LOCK = RLock()
def __init__(
self,
- name: str,
+ name: Optional[str] = None,
size_bytes: Optional[int] = None,
) -> None:
- super().__init__()
- self.name = name
- self._serializer = PickleSerializer()
+        """
+        Creates or attaches a shared dictionary backed by a SharedMemory
+        buffer. For create semantics, a unique name (string) and a max
+        dictionary size (expressed in bytes) must be provided. For attach
+        semantics, the name identifies the existing region and the size
+        is ignored.
+
+        The first process that creates the SharedDict is responsible for
+        (optionally) naming it and deciding the max size (in bytes) that
+        it may be. It does this via args to the c'tor.
+
+        Subsequent processes must pass the same name to attach but may
+        safely omit the size arg.
+
+        """
assert size_bytes is None or size_bytes > 0
+ self._serializer = PickleSerializer()
self.shared_memory = self._get_or_create_memory_block(name, size_bytes)
self._ensure_memory_initialization()
- self.lock = RLock()
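+        # If name was None, SharedMemory generated a unique one; record the
+        # region's actual name so other processes can attach via get_name().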
+ self.name = self.shared_memory.name
-    def get_name(self):
+    def get_name(self) -> str:
+ """Returns the name of the shared memory buffer backing the dict."""
return self.name
def _get_or_create_memory_block(
self,
- name: str,
+ name: Optional[str] = None,
size_bytes: Optional[int] = None,
) -> shared_memory.SharedMemory:
        if name is not None:
            try:
                # Try to attach to an existing region with this name first.
                return shared_memory.SharedMemory(name=name)
            except FileNotFoundError:
                pass
        # No region to attach to (or no name given): create one.
        assert size_bytes is not None
        return shared_memory.SharedMemory(name=name, create=True, size=size_bytes)
def _ensure_memory_initialization(self):
- memory_is_empty = bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
- if memory_is_empty:
- self.clear()
+ with SharedDict.LOCK:
+ memory_is_empty = bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
+ if memory_is_empty:
+ self.clear()
+
+ def _write_memory(self, db: Dict[Hashable, Any]) -> None:
+ data = self._serializer.dumps(db)
+ with SharedDict.LOCK:
+ try:
+ self.shared_memory.buf[: len(data)] = data
+ except ValueError as e:
+ raise ValueError("exceeds available storage") from e
+
+ def _read_memory(self) -> Dict[Hashable, Any]:
+ with SharedDict.LOCK:
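+            # Note: pickle.loads ignores any bytes past the end of the
+            # pickled object, so stale trailing bytes left over from a
+            # previously larger dict are harmless here.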
+ return self._serializer.loads(self.shared_memory.buf.tobytes())
+
+ @contextmanager
+    def _modify_dict(self) -> Iterator[Dict[Hashable, Any]]:
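+        """Context manager that reads the whole dict, lets the caller
+        mutate it, and then writes it back, all while holding the lock."""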
+ with SharedDict.LOCK:
+ db = self._read_memory()
+ yield db
+ self._write_memory(db)
def close(self) -> None:
+        """Unmap the shared dict and the memory behind it from this
+        process. Called automatically by __del__."""
if not hasattr(self, 'shared_memory'):
return
self.shared_memory.close()
def cleanup(self) -> None:
+ """Unlink the shared dict and memory behind it. Only the last process should
+ invoke this. Not called automatically."""
if not hasattr(self, 'shared_memory'):
return
- with SharedDict.MPLOCK:
+ with SharedDict.LOCK:
self.shared_memory.unlink()
def clear(self) -> None:
- self._save_memory({})
-
- def popitem(self, last: Optional[bool] = None) -> Any:
- with self._modify_db() as db:
- return db.popitem()
+ """Clear the dict."""
+ self._write_memory({})
- @contextmanager
- def _modify_db(self) -> Generator:
- with SharedDict.MPLOCK:
- db = self._read_memory()
- yield db
- self._save_memory(db)
+ def copy(self) -> Dict[Hashable, Any]:
+ """Returns a shallow copy of the dict."""
+ return self._read_memory()
- def __getitem__(self, key: str) -> Any:
+ def __getitem__(self, key: Hashable) -> Any:
return self._read_memory()[key]
- def __setitem__(self, key: str, value: Any) -> None:
- with self._modify_db() as db:
+ def __setitem__(self, key: Hashable, value: Any) -> None:
+ with self._modify_dict() as db:
db[key] = value
def __len__(self) -> int:
return len(self._read_memory())
- def __delitem__(self, key: str) -> None:
- with self._modify_db() as db:
+ def __delitem__(self, key: Hashable) -> None:
+ with self._modify_dict() as db:
del db[key]
- def __iter__(self) -> Iterator:
+ def __iter__(self) -> Iterator[Hashable]:
return iter(self._read_memory())
- def __reversed__(self):
+ def __reversed__(self) -> Iterator[Hashable]:
return reversed(self._read_memory())
def __del__(self) -> None:
self.close()
- def __contains__(self, key: str) -> bool:
+ def __contains__(self, key: Hashable) -> bool:
return key in self._read_memory()
    def __eq__(self, other: Any) -> bool:
        return self._read_memory() == other

    def __repr__(self) -> str:
        return repr(self._read_memory())
-    def get(self, key: str, default: Optional[Any] = None) -> Any:
+    def get(self, key: Hashable, default: Optional[Any] = None) -> Any:
+ """Gets the value associated with key or a default."""
return self._read_memory().get(key, default)
- def keys(self) -> KeysView[Any]:
+ def keys(self) -> KeysView[Hashable]:
return self._read_memory().keys()
def values(self) -> ValuesView[Any]:
return self._read_memory().values()
- def items(self) -> ItemsView:
+ def items(self) -> ItemsView[Hashable, Any]:
return self._read_memory().items()
- def pop(self, key: str, default: Optional[Any] = None):
- with self._modify_db() as db:
+ def popitem(self) -> Tuple[Hashable, Any]:
+ """Remove and return the last added item."""
+ with self._modify_dict() as db:
+ return db.popitem()
+
+ def pop(self, key: Hashable, default: Optional[Any] = None) -> Any:
+        """Remove and return the value associated with key or a default."""
+ with self._modify_dict() as db:
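+            # None doubles as the "no default" sentinel here, so, unlike
+            # dict.pop, pop(key, None) on a missing key raises KeyError.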
if default is None:
return db.pop(key)
return db.pop(key, default)
def update(self, other=(), /, **kwds):
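+        """Update the dict from another dict or an iterable of key/value
+        pairs, plus any keyword args."""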
- with self._modify_db() as db:
+ with self._modify_dict() as db:
db.update(other, **kwds)
- def setdefault(self, key: str, default: Optional[Any] = None):
- with self._modify_db() as db:
+    def setdefault(self, key: Hashable, default: Optional[Any] = None) -> Any:
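+        """Gets the value associated with key, first setting it to default
+        if key is not present."""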
+ with self._modify_dict() as db:
return db.setdefault(key, default)
- def _save_memory(self, db: Dict[str, Any]) -> None:
- with SharedDict.MPLOCK:
- data = self._serializer.dumps(db)
- try:
- self.shared_memory.buf[: len(data)] = data
- except ValueError as exc:
- raise ValueError("exceeds available storage") from exc
-
- def _read_memory(self) -> Dict[str, Any]:
- with SharedDict.MPLOCK:
- return self._serializer.loads(self.shared_memory.buf.tobytes())
-
if __name__ == '__main__':
    import doctest
    doctest.testmod()
"""shared_dict unittest."""
+import random
import unittest
import parallelize as p
import smart_future
class SharedDictTest(unittest.TestCase):
@p.parallelize(method=p.Method.PROCESS)
- def doit(self, n: int, dict_name: str):
- d = SharedDict(dict_name)
+ def doit(self, n: int, dict_name: str, parent_lock_id: int):
+        # Verify the class-level lock was inherited from the parent process.
+        self.assertEqual(id(SharedDict.LOCK), parent_lock_id)
+ d = SharedDict(dict_name, None)
try:
msg = f'Hello from shard {n}'
- d[n] = msg
- self.assertTrue(n in d)
- self.assertEqual(msg, d[n])
+ for x in range(0, 1000):
+ d[n] = msg
+ self.assertTrue(n in d)
+ self.assertEqual(msg, d[n])
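+                # Unsynchronized read of a random shard's slot; the value is
+                # not checked, this just exercises concurrent readers.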
+                d.get(random.randrange(0, 99), None)
return n
finally:
d.close()
    def test_basic_operations(self):
        # Reconstructed test harness; region name and size are illustrative.
        dict_name = 'shared_dict_basic_ops_test'
        d = SharedDict(dict_name, 8192)
        try:
            self.assertEqual(dict_name, d.get_name())
results = []
for n in range(100):
- f = self.doit(n, d.get_name())
+ f = self.doit(n, d.get_name(), id(SharedDict.LOCK))
results.append(f)
smart_future.wait_all(results)
for f in results:
self.assertTrue(f.wrapped_future.done())
for k in d:
self.assertEqual(d[k], f'Hello from shard {k}')
+            self.assertEqual(100, len(d))
finally:
d.close()
d.cleanup()
@p.parallelize(method=p.Method.PROCESS)
- def add_one(self, name: str):
+ def add_one(self, name: str, expected_lock_id: int):
d = SharedDict(name)
+ self.assertEqual(id(SharedDict.LOCK), expected_lock_id)
try:
for x in range(1000):
- with SharedDict.MPLOCK:
+ with SharedDict.LOCK:
d["sum"] += 1
finally:
d.close()
    def test_shared_dict_increment(self):
        # Reconstructed test harness; region name and size are illustrative.
        d = SharedDict('shared_dict_increment_test', 4096)
        try:
            d["sum"] = 0
results = []
for n in range(10):
- f = self.add_one(d.get_name())
+ f = self.add_one(d.get_name(), id(SharedDict.LOCK))
results.append(f)
smart_future.wait_all(results)
            self.assertEqual(10000, d["sum"])
        finally:
            d.close()
            d.cleanup()


if __name__ == '__main__':
    unittest.main()