Change locking boundaries for shared dict. Add a unit test.
authorScott <[email protected]>
Mon, 31 Jan 2022 05:29:34 +0000 (21:29 -0800)
committerScott <[email protected]>
Mon, 31 Jan 2022 05:29:34 +0000 (21:29 -0800)
Make smart_future's wait_any/wait_all re-raise exceptions that happened in futures.
Rename read_file_to_list to slurp_file (and line_transformations to line_transformers) in file_utils.
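
A rough sketch of the intended usage after this change, modeled on the new
tests/shared_dict_test.py (SharedDict, parallelize, and smart_future are this
repo's modules; the add_one worker and the 'usage_example' name below are
illustrative, not part of the change):

    import parallelize as p
    import smart_future
    from collect.shared_dict import SharedDict

    @p.parallelize(method=p.Method.PROCESS)
    def add_one(name: str):
        # Attach to an already-created shared dict by name; no size needed.
        d = SharedDict(name)
        try:
            # Hold the class-level lock across the whole read-modify-write.
            # The nested read/write re-acquire it, hence the switch to RLock.
            with SharedDict.MPLOCK:
                d['count'] = d['count'] + 1
        finally:
            d.close()

    d = SharedDict('usage_example', 4096)  # creator supplies size_bytes
    d['count'] = 0
    futures = [add_one(d.get_name()) for _ in range(4)]
    # wait_all (and wait_any) now re-raise a worker's exception in the caller
    # instead of just printing its traceback.
    smart_future.wait_all(futures)
    assert d['count'] == 4
    d.close()
    d.cleanup()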

collect/shared_dict.py
file_utils.py
smart_future.py
tests/shared_dict_test.py [new file with mode: 0755]

diff --git a/collect/shared_dict.py b/collect/shared_dict.py
index 93aa452d50f9bb383d60662a3925ee51e438b644..0d8e7c2f7a36aa5ddb7c54c72aecddbf56df71c3 100644 (file)
@@ -30,7 +30,7 @@ This class is based on https://github.com/luizalabs/shared-memory-dict
 import pickle
 from contextlib import contextmanager
 from functools import wraps
-from multiprocessing import shared_memory, Lock
+from multiprocessing import shared_memory, RLock
 from typing import (
     Any,
     Dict,
@@ -64,25 +64,33 @@ class PickleSerializer:
 
 class SharedDict(object):
     NULL_BYTE = b'\x00'
-    MPLOCK = Lock()
+    MPLOCK = RLock()
 
     def __init__(
         self,
         name: str,
-        size: int,
+        size_bytes: Optional[int] = None,
     ) -> None:
         super().__init__()
+        self.name = name
         self._serializer = PickleSerializer()
-        self.shared_memory = self._get_or_create_memory_block(name, size)
+        self.shared_memory = self._get_or_create_memory_block(name, size_bytes)
         self._ensure_memory_initialization()
+        self.lock = RLock()
+
+    def get_name(self):
+        return self.name
 
     def _get_or_create_memory_block(
-        self, name: str, size: int
+        self,
+        name: str,
+        size_bytes: Optional[int] = None,
     ) -> shared_memory.SharedMemory:
         try:
             return shared_memory.SharedMemory(name=name)
         except FileNotFoundError:
-            return shared_memory.SharedMemory(name=name, create=True, size=size)
+            assert size_bytes
+            return shared_memory.SharedMemory(name=name, create=True, size=size_bytes)
 
     def _ensure_memory_initialization(self):
         memory_is_empty = (
@@ -101,7 +109,6 @@ class SharedDict(object):
             return
         self.shared_memory.unlink()
 
-    @synchronized(MPLOCK)
     def clear(self) -> None:
         self._save_memory({})
 
@@ -109,12 +116,12 @@ class SharedDict(object):
         with self._modify_db() as db:
             return db.popitem()
 
-    @synchronized(MPLOCK)
     @contextmanager
     def _modify_db(self) -> Generator:
-        db = self._read_memory()
-        yield db
-        self._save_memory(db)
+        with SharedDict.MPLOCK:
+            db = self._read_memory()
+            yield db
+            self._save_memory(db)
 
     def __getitem__(self, key: str) -> Any:
         return self._read_memory()[key]
@@ -181,14 +188,16 @@ class SharedDict(object):
             return db.setdefault(key, default)
 
     def _save_memory(self, db: Dict[str, Any]) -> None:
-        data = self._serializer.dumps(db)
-        try:
-            self.shared_memory.buf[: len(data)] = data
-        except ValueError as exc:
-            raise ValueError("exceeds available storage") from exc
+        with SharedDict.MPLOCK:
+            data = self._serializer.dumps(db)
+            try:
+                self.shared_memory.buf[: len(data)] = data
+            except ValueError as exc:
+                raise ValueError("exceeds available storage") from exc
 
     def _read_memory(self) -> Dict[str, Any]:
-        return self._serializer.loads(self.shared_memory.buf.tobytes())
+        with SharedDict.MPLOCK:
+            return self._serializer.loads(self.shared_memory.buf.tobytes())
 
 
 if __name__ == '__main__':
diff --git a/file_utils.py b/file_utils.py
index 22210e4444afcb5223fbfcf70dcfe839baa15edd..cd37f3069c70efd5c0f835e3362adbdf18d52e24 100644 (file)
@@ -33,15 +33,18 @@ def remove_hash_comments(x):
     return re.sub(r'#.*$', '', x)
 
 
-def read_file_to_list(
-    filename: str, *, skip_blank_lines=False, line_transformations=[]
+def slurp_file(
+    filename: str,
+    *,
+    skip_blank_lines=False,
+    line_transformers=[],
 ):
     ret = []
     if not file_is_readable(filename):
         raise Exception(f'{filename} can\'t be read.')
     with open(filename) as rf:
         for line in rf:
-            for transformation in line_transformations:
+            for transformation in line_transformers:
                 line = transformation(line)
             if skip_blank_lines and line == '':
                 continue
diff --git a/smart_future.py b/smart_future.py
index c96c5a712d9f8d829eb73a21082f2ffbc0f41d48..2f3cbd9a9949f681e8a7cc70bd35ca43626e2861 100644 (file)
@@ -40,8 +40,11 @@ def wait_any(
             if log_exceptions and not f.cancelled():
                 exception = f.exception()
                 if exception is not None:
+                    logger.warning(
+                        f'Future {id(f)} raised an unhandled exception and exited.'
+                    )
                     logger.exception(exception)
-                    traceback.print_tb(exception.__traceback__)
+                    raise exception
             yield smart_future_by_real_future[f]
     if callback is not None:
         callback()
@@ -62,8 +65,11 @@ def wait_all(
             if not f.cancelled():
                 exception = f.exception()
                 if exception is not None:
+                    logger.warning(
+                        f'Future {id(f)} raised an unhandled exception and exited.'
+                    )
                     logger.exception(exception)
-                    traceback.print_tb(exception.__traceback__)
+                    raise exception
     assert len(done) == len(real_futures)
     assert len(not_done) == 0
 
diff --git a/tests/shared_dict_test.py b/tests/shared_dict_test.py
new file mode 100755 (executable)
index 0000000..c8294c5
--- /dev/null
+++ b/tests/shared_dict_test.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python3
+
+import unittest
+
+from collect.shared_dict import SharedDict
+import parallelize as p
+import smart_future
+import unittest_utils
+
+
+class SharedDictTest(unittest.TestCase):
+    @p.parallelize(method=p.Method.PROCESS)
+    def doit(self, n: int, dict_name: str):
+        d = SharedDict(dict_name)
+        try:
+            msg = f'Hello from shard {n}'
+            d[n] = msg
+            self.assertTrue(n in d)
+            self.assertEqual(msg, d[n])
+            return n
+        finally:
+            d.close()
+
+    def test_basic_operations(self):
+        dict_name = 'test_shared_dict'
+        d = SharedDict(dict_name, 4096)
+        try:
+            self.assertEqual(dict_name, d.get_name())
+            results = []
+            for n in range(100):
+                f = self.doit(n, d.get_name())
+                results.append(f)
+            smart_future.wait_all(results)
+            for f in results:
+                self.assertTrue(f.wrapped_future.done())
+            for k in d:
+                self.assertEqual(d[k], f'Hello from shard {k}')
+        finally:
+            d.close()
+            d.cleanup()
+
+    @p.parallelize(method=p.Method.PROCESS)
+    def add_one(self, name: str):
+        d = SharedDict(name)
+        try:
+            for x in range(1000):
+                with SharedDict.MPLOCK:
+                    d["sum"] += 1
+        finally:
+            d.close()
+
+    def test_locking_works(self):
+        dict_name = 'test_shared_dict_lock'
+        d = SharedDict(dict_name, 4096)
+        try:
+            d["sum"] = 0
+            results = []
+            for n in range(10):
+                f = self.add_one(d.get_name())
+                results.append(f)
+            smart_future.wait_all(results)
+            self.assertEqual(10000, d["sum"])
+        finally:
+            d.close()
+            d.cleanup()
+
+
+if __name__ == '__main__':
+    unittest.main()