From: Scott Date: Mon, 31 Jan 2022 05:29:34 +0000 (-0800) Subject: Change locking boundaries for shared dict. Add a unit test. X-Git-Url: https://wannabe.guru.org/gitweb/?a=commitdiff_plain;h=865825894beeedd47d26dd092d40bfee582f5475;p=python_utils.git Change locking boundaries for shared dict. Add a unit test. Make smart_futures re-raise exceptions that happened in futures. Mess with file_utils. --- diff --git a/collect/shared_dict.py b/collect/shared_dict.py index 93aa452..0d8e7c2 100644 --- a/collect/shared_dict.py +++ b/collect/shared_dict.py @@ -30,7 +30,7 @@ This class is based on https://github.com/luizalabs/shared-memory-dict import pickle from contextlib import contextmanager from functools import wraps -from multiprocessing import shared_memory, Lock +from multiprocessing import shared_memory, RLock from typing import ( Any, Dict, @@ -64,25 +64,33 @@ class PickleSerializer: class SharedDict(object): NULL_BYTE = b'\x00' - MPLOCK = Lock() + MPLOCK = RLock() def __init__( self, name: str, - size: int, + size_bytes: Optional[int] = None, ) -> None: super().__init__() + self.name = name self._serializer = PickleSerializer() - self.shared_memory = self._get_or_create_memory_block(name, size) + self.shared_memory = self._get_or_create_memory_block(name, size_bytes) self._ensure_memory_initialization() + self.lock = RLock() + + def get_name(self): + return self.name def _get_or_create_memory_block( - self, name: str, size: int + self, + name: str, + size_bytes: Optional[int] = None, ) -> shared_memory.SharedMemory: try: return shared_memory.SharedMemory(name=name) except FileNotFoundError: - return shared_memory.SharedMemory(name=name, create=True, size=size) + assert size_bytes + return shared_memory.SharedMemory(name=name, create=True, size=size_bytes) def _ensure_memory_initialization(self): memory_is_empty = ( @@ -101,7 +109,6 @@ class SharedDict(object): return self.shared_memory.unlink() - @synchronized(MPLOCK) def clear(self) -> None: self._save_memory({}) @@ -109,12 +116,12 @@ class SharedDict(object): with self._modify_db() as db: return db.popitem() - @synchronized(MPLOCK) @contextmanager def _modify_db(self) -> Generator: - db = self._read_memory() - yield db - self._save_memory(db) + with SharedDict.MPLOCK: + db = self._read_memory() + yield db + self._save_memory(db) def __getitem__(self, key: str) -> Any: return self._read_memory()[key] @@ -181,14 +188,16 @@ class SharedDict(object): return db.setdefault(key, default) def _save_memory(self, db: Dict[str, Any]) -> None: - data = self._serializer.dumps(db) - try: - self.shared_memory.buf[: len(data)] = data - except ValueError as exc: - raise ValueError("exceeds available storage") from exc + with SharedDict.MPLOCK: + data = self._serializer.dumps(db) + try: + self.shared_memory.buf[: len(data)] = data + except ValueError as exc: + raise ValueError("exceeds available storage") from exc def _read_memory(self) -> Dict[str, Any]: - return self._serializer.loads(self.shared_memory.buf.tobytes()) + with SharedDict.MPLOCK: + return self._serializer.loads(self.shared_memory.buf.tobytes()) if __name__ == '__main__': diff --git a/file_utils.py b/file_utils.py index 22210e4..cd37f30 100644 --- a/file_utils.py +++ b/file_utils.py @@ -33,15 +33,18 @@ def remove_hash_comments(x): return re.sub(r'#.*$', '', x) -def read_file_to_list( - filename: str, *, skip_blank_lines=False, line_transformations=[] +def slurp_file( + filename: str, + *, + skip_blank_lines=False, + line_transformers=[], ): ret = [] if not file_is_readable(filename): raise Exception(f'{filename} can\'t be read.') with open(filename) as rf: for line in rf: - for transformation in line_transformations: + for transformation in line_transformers: line = transformation(line) if skip_blank_lines and line == '': continue diff --git a/smart_future.py b/smart_future.py index c96c5a7..2f3cbd9 100644 --- a/smart_future.py +++ b/smart_future.py @@ -40,8 +40,11 @@ def wait_any( if log_exceptions and not f.cancelled(): exception = f.exception() if exception is not None: + logger.warning( + f'Future {id(f)} raised an unhandled exception and exited.' + ) logger.exception(exception) - traceback.print_tb(exception.__traceback__) + raise exception yield smart_future_by_real_future[f] if callback is not None: callback() @@ -62,8 +65,11 @@ def wait_all( if not f.cancelled(): exception = f.exception() if exception is not None: + logger.warning( + f'Future {id(f)} raised an unhandled exception and exited.' + ) logger.exception(exception) - traceback.print_tb(exception.__traceback__) + raise exception assert len(done) == len(real_futures) assert len(not_done) == 0 diff --git a/tests/shared_dict_test.py b/tests/shared_dict_test.py new file mode 100755 index 0000000..c8294c5 --- /dev/null +++ b/tests/shared_dict_test.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 + +import unittest + +from collect.shared_dict import SharedDict +import parallelize as p +import smart_future +import unittest_utils + + +class SharedDictTest(unittest.TestCase): + @p.parallelize(method=p.Method.PROCESS) + def doit(self, n: int, dict_name: str): + d = SharedDict(dict_name) + try: + msg = f'Hello from shard {n}' + d[n] = msg + self.assertTrue(n in d) + self.assertEqual(msg, d[n]) + return n + finally: + d.close() + + def test_basic_operations(self): + dict_name = 'test_shared_dict' + d = SharedDict(dict_name, 4096) + try: + self.assertEqual(dict_name, d.get_name()) + results = [] + for n in range(100): + f = self.doit(n, d.get_name()) + results.append(f) + smart_future.wait_all(results) + for f in results: + self.assertTrue(f.wrapped_future.done()) + for k in d: + self.assertEqual(d[k], f'Hello from shard {k}') + finally: + d.close() + d.cleanup() + + @p.parallelize(method=p.Method.PROCESS) + def add_one(self, name: str): + d = SharedDict(name) + try: + for x in range(1000): + with SharedDict.MPLOCK: + d["sum"] += 1 + finally: + d.close() + + def test_locking_works(self): + dict_name = 'test_shared_dict_lock' + d = SharedDict(dict_name, 4096) + try: + d["sum"] = 0 + results = [] + for n in range(10): + f = self.add_one(d.get_name()) + results.append(f) + smart_future.wait_all(results) + self.assertEqual(10000, d["sum"]) + finally: + d.close() + d.cleanup() + + +if __name__ == '__main__': + unittest.main()