From 4d69e7bee5985ff744ef543f9c9f20bb37f4fd6d Mon Sep 17 00:00:00 2001
From: Scott Gasch
Date: Tue, 30 Aug 2022 15:36:23 -0700
Subject: [PATCH] Changes towards splitting up the library and (maybe?)
 publishing on PyPI.

---
 NOTICE                    |   5 +-
 collect/shared_dict.py    | 149 ++++++++++++++++++++++++--------------
 config.py                 |  51 +++++++++++--
 tests/shared_dict_test.py |  24 +++---
 4 files changed, 158 insertions(+), 71 deletions(-)

diff --git a/NOTICE b/NOTICE
index 59d5c11..d61aae6 100644
--- a/NOTICE
+++ b/NOTICE
@@ -41,9 +41,10 @@ contains URLs pointing at the source of the forked code.
 
 Scott's modifications include:
     + Adding a unittest (tests/shared_dict_test.py),
+    + Added type hints,
+    + Changes to locking scope,
     + Minor cleanup and style tweaks,
-    + Added sphinx style pydocs,
-    + Added type hints.
+    + Added sphinx style pydocs.
 
 3. The timeout decorator in decorator_utils.py is based on original
    work published in ActiveState code recipes and covered by the PSF
diff --git a/collect/shared_dict.py b/collect/shared_dict.py
index 3207927..dccae4c 100644
--- a/collect/shared_dict.py
+++ b/collect/shared_dict.py
@@ -33,55 +33,81 @@ from multiprocessing import RLock, shared_memory
 from typing import (
     Any,
     Dict,
-    Generator,
+    Hashable,
     ItemsView,
     Iterator,
     KeysView,
     Optional,
+    Tuple,
     ValuesView,
 )
 
 
 class PickleSerializer:
-    def dumps(self, obj: dict) -> bytes:
+    """A serializer that uses pickling.  Used to read/write bytes in the shared
+    memory region and interpret them as a dict."""
+
+    def dumps(self, obj: Dict[Hashable, Any]) -> bytes:
         try:
             return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
         except pickle.PicklingError as e:
-            raise Exception(e)
+            raise Exception from e
 
-    def loads(self, data: bytes) -> dict:
+    def loads(self, data: bytes) -> Dict[Hashable, Any]:
         try:
             return pickle.loads(data)
         except pickle.UnpicklingError as e:
-            raise Exception(e)
+            raise Exception from e
 
 
-# TODO: protobuf serializer?
+# TODOs: profile the serializers and figure out the fastest one.  Can
+# we use a ChainMap to avoid the constant de/re-serialization of the
+# whole thing?
 
 
 class SharedDict(object):
+    """This class emulates the dict container but uses a
+    multiprocessing.shared_memory region to back the dict such that it
+    can be read and written by multiple independent processes at the
+    same time.  Because it constantly de/re-serializes the dict, it is
+    much slower than a normal dict.
+
+    """
+
     NULL_BYTE = b'\x00'
-    MPLOCK = RLock()
+    LOCK = RLock()
 
     def __init__(
         self,
-        name: str,
+        name: Optional[str] = None,
         size_bytes: Optional[int] = None,
     ) -> None:
-        super().__init__()
-        self.name = name
-        self._serializer = PickleSerializer()
+        """
+        Creates or attaches a shared dictionary backed by a SharedMemory buffer.
+        For create semantics, a unique name (string) and a max dictionary size
+        (expressed in bytes) must be provided.  For attach semantics, the name
+        identifies an existing region and the size is ignored.
+
+        The first process that creates the SharedDict is responsible for
+        (optionally) naming it and deciding the max size (in bytes) that
+        it may be.  It does this via args to the constructor.
+
+        Subsequent processes attach by name and may safely omit the size arg.
+
+        """
         assert size_bytes is None or size_bytes > 0
+        self._serializer = PickleSerializer()
         self.shared_memory = self._get_or_create_memory_block(name, size_bytes)
         self._ensure_memory_initialization()
-        self.lock = RLock()
+        self.name = self.shared_memory.name
 
     def get_name(self):
+        """Returns the name of the shared memory buffer backing the dict."""
         return self.name
 
     def _get_or_create_memory_block(
         self,
-        name: str,
+        name: Optional[str] = None,
         size_bytes: Optional[int] = None,
     ) -> shared_memory.SharedMemory:
         try:
@@ -91,59 +117,77 @@ class SharedDict(object):
             return shared_memory.SharedMemory(name=name, create=True, size=size_bytes)
 
     def _ensure_memory_initialization(self):
-        memory_is_empty = bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
-        if memory_is_empty:
-            self.clear()
+        with SharedDict.LOCK:
+            memory_is_empty = bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
+            if memory_is_empty:
+                self.clear()
+
+    def _write_memory(self, db: Dict[Hashable, Any]) -> None:
+        data = self._serializer.dumps(db)
+        with SharedDict.LOCK:
+            try:
+                self.shared_memory.buf[: len(data)] = data
+            except ValueError as e:
+                raise ValueError("exceeds available storage") from e
+
+    def _read_memory(self) -> Dict[Hashable, Any]:
+        with SharedDict.LOCK:
+            return self._serializer.loads(self.shared_memory.buf.tobytes())
+
+    @contextmanager
+    def _modify_dict(self):
+        with SharedDict.LOCK:
+            db = self._read_memory()
+            yield db
+            self._write_memory(db)
 
     def close(self) -> None:
+        """Unmap the shared dict and the memory behind it from this
+        process.  Called automatically by __del__."""
         if not hasattr(self, 'shared_memory'):
             return
         self.shared_memory.close()
 
     def cleanup(self) -> None:
+        """Unlink the shared dict and the memory behind it.  Only the last
+        process should invoke this.  Not called automatically."""
         if not hasattr(self, 'shared_memory'):
             return
-        with SharedDict.MPLOCK:
+        with SharedDict.LOCK:
             self.shared_memory.unlink()
 
     def clear(self) -> None:
-        self._save_memory({})
-
-    def popitem(self, last: Optional[bool] = None) -> Any:
-        with self._modify_db() as db:
-            return db.popitem()
+        """Clear the dict."""
+        self._write_memory({})
 
-    @contextmanager
-    def _modify_db(self) -> Generator:
-        with SharedDict.MPLOCK:
-            db = self._read_memory()
-            yield db
-            self._save_memory(db)
+    def copy(self) -> Dict[Hashable, Any]:
+        """Returns a shallow copy of the dict."""
+        return self._read_memory()
 
-    def __getitem__(self, key: str) -> Any:
+    def __getitem__(self, key: Hashable) -> Any:
         return self._read_memory()[key]
 
-    def __setitem__(self, key: str, value: Any) -> None:
-        with self._modify_db() as db:
+    def __setitem__(self, key: Hashable, value: Any) -> None:
+        with self._modify_dict() as db:
             db[key] = value
 
     def __len__(self) -> int:
         return len(self._read_memory())
 
-    def __delitem__(self, key: str) -> None:
-        with self._modify_db() as db:
+    def __delitem__(self, key: Hashable) -> None:
+        with self._modify_dict() as db:
             del db[key]
 
-    def __iter__(self) -> Iterator:
+    def __iter__(self) -> Iterator[Hashable]:
         return iter(self._read_memory())
 
-    def __reversed__(self):
+    def __reversed__(self) -> Iterator[Hashable]:
         return reversed(self._read_memory())
 
     def __del__(self) -> None:
         self.close()
 
-    def __contains__(self, key: str) -> bool:
+    def __contains__(self, key: Hashable) -> bool:
         return key in self._read_memory()
 
     def __eq__(self, other: Any) -> bool:
@@ -159,43 +203,38 @@ class SharedDict(object):
         return repr(self._read_memory())
 
     def get(self, key: str, default: Optional[Any] = None) -> Any:
+        """Gets the value associated with key or a default."""
         return self._read_memory().get(key, default)
 
-    def keys(self) -> KeysView[Any]:
+    def keys(self) -> KeysView[Hashable]:
         return self._read_memory().keys()
 
     def values(self) -> ValuesView[Any]:
         return self._read_memory().values()
 
-    def items(self) -> ItemsView:
+    def items(self) -> ItemsView[Hashable, Any]:
         return self._read_memory().items()
 
-    def pop(self, key: str, default: Optional[Any] = None):
-        with self._modify_db() as db:
+    def popitem(self) -> Tuple[Hashable, Any]:
+        """Remove and return the last added item."""
+        with self._modify_dict() as db:
+            return db.popitem()
+
+    def pop(self, key: Hashable, default: Optional[Any] = None) -> Any:
+        """Remove and return the value associated with key or a default."""
+        with self._modify_dict() as db:
             if default is None:
                 return db.pop(key)
             return db.pop(key, default)
 
     def update(self, other=(), /, **kwds):
-        with self._modify_db() as db:
+        with self._modify_dict() as db:
             db.update(other, **kwds)
 
-    def setdefault(self, key: str, default: Optional[Any] = None):
-        with self._modify_db() as db:
+    def setdefault(self, key: Hashable, default: Optional[Any] = None):
+        with self._modify_dict() as db:
             return db.setdefault(key, default)
 
-    def _save_memory(self, db: Dict[str, Any]) -> None:
-        with SharedDict.MPLOCK:
-            data = self._serializer.dumps(db)
-            try:
-                self.shared_memory.buf[: len(data)] = data
-            except ValueError as exc:
-                raise ValueError("exceeds available storage") from exc
-
-    def _read_memory(self) -> Dict[str, Any]:
-        with SharedDict.MPLOCK:
-            return self._serializer.loads(self.shared_memory.buf.tobytes())
-
 
 if __name__ == '__main__':
     import doctest
diff --git a/config.py b/config.py
index a98db57..4d88514 100644
--- a/config.py
+++ b/config.py
@@ -90,8 +90,6 @@
 import re
 import sys
 from typing import Any, Dict, List, Optional, Tuple
 
-import scott_secrets
-
 # This module is commonly used by others in here and should avoid
 # taking any unnecessary dependencies back on them.
@@ -330,6 +328,46 @@ class Config:
             env = env[1:]
         return var, env, chunks
 
+    @staticmethod
+    def _to_bool(in_str: str) -> bool:
+        """
+        Args:
+            in_str: the string to convert to boolean
+
+        Returns:
+            A boolean equivalent of the original string based on its contents.
+            All conversion is case insensitive.  A positive boolean (True) is
+            returned if the string value is any of the following:
+
+            * "true"
+            * "t"
+            * "1"
+            * "yes"
+            * "y"
+            * "on"
+
+            Otherwise False is returned.
+
+        >>> Config._to_bool('True')
+        True
+
+        >>> Config._to_bool('1')
+        True
+
+        >>> Config._to_bool('yes')
+        True
+
+        >>> Config._to_bool('no')
+        False
+
+        >>> Config._to_bool('huh?')
+        False
+
+        >>> Config._to_bool('on')
+        True
+        """
+        return in_str.lower() in ("true", "1", "yes", "y", "t", "on")
+
     def _augment_sys_argv_from_environment_variables(self):
         """Internal.  Look at the system environment for variables that match
         commandline arg names.  This is done via some munging such that:
@@ -366,9 +404,7 @@ class Config:
             self.saved_messages.append(
                 f'Initialized from environment: {var} = {value}'
             )
-            from string_utils import to_bool
-
-            if len(chunks) == 1 and to_bool(value):
+            if len(chunks) == 1 and Config._to_bool(value):
                 sys.argv.append(var)
             elif len(chunks) > 1:
                 sys.argv.append(var)
@@ -421,8 +457,11 @@ class Config:
         if loadfile[:3] == 'zk:':
             from kazoo.client import KazooClient
 
+            import scott_secrets
+
             try:
                 if self.zk is None:
+
                     self.zk = KazooClient(
                         hosts=scott_secrets.ZOOKEEPER_NODES,
                         use_ssl=True,
@@ -545,6 +584,8 @@ class Config:
         if not self.zk:
             from kazoo.client import KazooClient
 
+            import scott_secrets
+
             self.zk = KazooClient(
                 hosts=scott_secrets.ZOOKEEPER_NODES,
                 use_ssl=True,
diff --git a/tests/shared_dict_test.py b/tests/shared_dict_test.py
index 230bdb9..68a3788 100755
--- a/tests/shared_dict_test.py
+++ b/tests/shared_dict_test.py
@@ -4,6 +4,7 @@
 
 """shared_dict unittest."""
 
+import random
 import unittest
 
 import parallelize as p
@@ -14,13 +15,16 @@ from collect.shared_dict import SharedDict
 
 class SharedDictTest(unittest.TestCase):
     @p.parallelize(method=p.Method.PROCESS)
-    def doit(self, n: int, dict_name: str):
-        d = SharedDict(dict_name)
+    def doit(self, n: int, dict_name: str, parent_lock_id: int):
+        assert id(SharedDict.LOCK) == parent_lock_id
+        d = SharedDict(dict_name, None)
         try:
             msg = f'Hello from shard {n}'
-            d[n] = msg
-            self.assertTrue(n in d)
-            self.assertEqual(msg, d[n])
+            for x in range(0, 1000):
+                d[n] = msg
+                self.assertTrue(n in d)
+                self.assertEqual(msg, d[n])
+                y = d.get(random.randrange(0, 99), None)
             return n
         finally:
             d.close()
@@ -32,23 +36,25 @@
             self.assertEqual(dict_name, d.get_name())
             results = []
             for n in range(100):
-                f = self.doit(n, d.get_name())
+                f = self.doit(n, d.get_name(), id(SharedDict.LOCK))
                 results.append(f)
             smart_future.wait_all(results)
             for f in results:
                 self.assertTrue(f.wrapped_future.done())
             for k in d:
                 self.assertEqual(d[k], f'Hello from shard {k}')
+            assert len(d) == 100
         finally:
             d.close()
             d.cleanup()
 
     @p.parallelize(method=p.Method.PROCESS)
-    def add_one(self, name: str):
+    def add_one(self, name: str, expected_lock_id: int):
         d = SharedDict(name)
+        self.assertEqual(id(SharedDict.LOCK), expected_lock_id)
         try:
             for x in range(1000):
-                with SharedDict.MPLOCK:
+                with SharedDict.LOCK:
                     d["sum"] += 1
         finally:
             d.close()
@@ -60,7 +66,7 @@
             d["sum"] = 0
results = [] for n in range(10): - f = self.add_one(d.get_name()) + f = self.add_one(d.get_name(), id(SharedDict.LOCK)) results.append(f) smart_future.wait_all(results) self.assertEqual(10000, d["sum"]) -- 2.45.0
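
A minimal usage sketch of the create/attach semantics introduced above
(for illustration only: the name 'demo_dict' and the 4096-byte size are
assumptions, not part of the patch, and multiprocessing.Process stands
in for the parallelize decorator that the unittest uses):

    # Create/attach example for collect.shared_dict.SharedDict.
    # Workers run one at a time so this sketch does not depend on
    # SharedDict.LOCK being shared across processes (that only holds
    # under fork-style process creation, as in the unittest).
    from multiprocessing import Process

    from collect.shared_dict import SharedDict


    def worker(name: str, n: int) -> None:
        d = SharedDict(name)  # attach to the existing region by name
        try:
            d[n] = f'Hello from shard {n}'
        finally:
            d.close()  # unmap from this process; the region survives


    if __name__ == '__main__':
        d = SharedDict('demo_dict', 4096)  # create: unique name + max size in bytes
        try:
            for n in range(4):
                p = Process(target=worker, args=('demo_dict', n))
                p.start()
                p.join()
            assert len(d) == 4  # one entry written by each worker
        finally:
            d.close()
            d.cleanup()  # last process unlinks the backing shared memory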