Make smart_futures re-raise exceptions that happened in futures.
Rename file_utils.read_file_to_list to slurp_file and its line_transformations parameter to line_transformers.
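
With this change, an exception raised inside a parallelized task propagates out of the wait_any / wait_all helpers instead of only being logged and swallowed. A minimal sketch of the intended caller-side behavior; might_explode is a made-up example task, while parallelize, Method.PROCESS and smart_future.wait_all are the same helpers the new shared_dict test below relies on:

    import parallelize as p
    import smart_future

    @p.parallelize(method=p.Method.PROCESS)
    def might_explode(n: int) -> int:
        # Hypothetical task: one shard fails on purpose.
        if n == 3:
            raise ValueError(f'shard {n} blew up')
        return n

    results = [might_explode(n) for n in range(5)]
    try:
        # Previously wait_all() only logged the ValueError; now it re-raises,
        # so the caller sees the failure and can react to it.
        smart_future.wait_all(results)
    except ValueError as e:
        print(f'caught: {e}')

Callers that depended on the old log-and-continue behavior now need a try/except around the wait_* calls.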
import pickle
from contextlib import contextmanager
from functools import wraps
-from multiprocessing import shared_memory, Lock
+from multiprocessing import shared_memory, RLock
from typing import (
Any,
Dict,
class SharedDict(object):
NULL_BYTE = b'\x00'
- MPLOCK = Lock()
+ MPLOCK = RLock()
def __init__(
self,
name: str,
- size: int,
+ size_bytes: Optional[int] = None,
) -> None:
super().__init__()
+ self.name = name
self._serializer = PickleSerializer()
- self.shared_memory = self._get_or_create_memory_block(name, size)
+ self.shared_memory = self._get_or_create_memory_block(name, size_bytes)
self._ensure_memory_initialization()
+ self.lock = RLock()
+
+ def get_name(self):
+ return self.name
def _get_or_create_memory_block(
- self, name: str, size: int
+ self,
+ name: str,
+ size_bytes: Optional[int] = None,
) -> shared_memory.SharedMemory:
try:
return shared_memory.SharedMemory(name=name)
except FileNotFoundError:
- return shared_memory.SharedMemory(name=name, create=True, size=size)
+ assert size_bytes
+ return shared_memory.SharedMemory(name=name, create=True, size=size_bytes)
def _ensure_memory_initialization(self):
memory_is_empty = (
return
self.shared_memory.unlink()
- @synchronized(MPLOCK)
def clear(self) -> None:
self._save_memory({})
with self._modify_db() as db:
return db.popitem()
- @synchronized(MPLOCK)
@contextmanager
def _modify_db(self) -> Generator:
- db = self._read_memory()
- yield db
- self._save_memory(db)
+ with SharedDict.MPLOCK:
+ db = self._read_memory()
+ yield db
+ self._save_memory(db)
def __getitem__(self, key: str) -> Any:
return self._read_memory()[key]
return db.setdefault(key, default)
def _save_memory(self, db: Dict[str, Any]) -> None:
- data = self._serializer.dumps(db)
- try:
- self.shared_memory.buf[: len(data)] = data
- except ValueError as exc:
- raise ValueError("exceeds available storage") from exc
+ with SharedDict.MPLOCK:
+ data = self._serializer.dumps(db)
+ try:
+ self.shared_memory.buf[: len(data)] = data
+ except ValueError as exc:
+ raise ValueError("exceeds available storage") from exc
def _read_memory(self) -> Dict[str, Any]:
- return self._serializer.loads(self.shared_memory.buf.tobytes())
+ with SharedDict.MPLOCK:
+ return self._serializer.loads(self.shared_memory.buf.tobytes())
if __name__ == '__main__':
return re.sub(r'#.*$', '', x)
-def read_file_to_list(
- filename: str, *, skip_blank_lines=False, line_transformations=[]
+def slurp_file(
+ filename: str,
+ *,
+ skip_blank_lines=False,
+ line_transformers=[],
):
ret = []
if not file_is_readable(filename):
raise Exception(f'{filename} can\'t be read.')
with open(filename) as rf:
for line in rf:
- for transformation in line_transformations:
+ for transformation in line_transformers:
line = transformation(line)
if skip_blank_lines and line == '':
continue
if log_exceptions and not f.cancelled():
exception = f.exception()
if exception is not None:
+ logger.warning(
+ f'Future {id(f)} raised an unhandled exception and exited.'
+ )
logger.exception(exception)
- traceback.print_tb(exception.__traceback__)
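+ # Propagate the failure to the caller instead of only logging it.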
+ raise exception
yield smart_future_by_real_future[f]
if callback is not None:
callback()
if not f.cancelled():
exception = f.exception()
if exception is not None:
+ logger.warning(
+ f'Future {id(f)} raised an unhandled exception and exited.'
+ )
logger.exception(exception)
- traceback.print_tb(exception.__traceback__)
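+ # Propagate the failure to the caller instead of only logging it.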
+ raise exception
assert len(done) == len(real_futures)
assert len(not_done) == 0
--- /dev/null
+#!/usr/bin/env python3
+
+import unittest
+
+from collect.shared_dict import SharedDict
+import parallelize as p
+import smart_future
+import unittest_utils
+
+
+class SharedDictTest(unittest.TestCase):
+ @p.parallelize(method=p.Method.PROCESS)
+ def doit(self, n: int, dict_name: str):
+ d = SharedDict(dict_name)
+ try:
+ msg = f'Hello from shard {n}'
+ d[n] = msg
+ self.assertTrue(n in d)
+ self.assertEqual(msg, d[n])
+ return n
+ finally:
+ d.close()
+
+ def test_basic_operations(self):
+ dict_name = 'test_shared_dict'
+ d = SharedDict(dict_name, 4096)
+ try:
+ self.assertEqual(dict_name, d.get_name())
+ results = []
+ for n in range(100):
+ f = self.doit(n, d.get_name())
+ results.append(f)
+ smart_future.wait_all(results)
+ for f in results:
+ self.assertTrue(f.wrapped_future.done())
+ for k in d:
+ self.assertEqual(d[k], f'Hello from shard {k}')
+ finally:
+ d.close()
+ d.cleanup()
+
+ @p.parallelize(method=p.Method.PROCESS)
+ def add_one(self, name: str):
+ d = SharedDict(name)
+ try:
+ for x in range(1000):
+ with SharedDict.MPLOCK:
+ d["sum"] += 1
+ finally:
+ d.close()
+
+ def test_locking_works(self):
+ dict_name = 'test_shared_dict_lock'
+ d = SharedDict(dict_name, 4096)
+ try:
+ d["sum"] = 0
+ results = []
+ for n in range(10):
+ f = self.add_one(d.get_name())
+ results.append(f)
+ smart_future.wait_all(results)
+ self.assertEqual(10000, d["sum"])
+ finally:
+ d.close()
+ d.cleanup()
+
+
+if __name__ == '__main__':
+ unittest.main()