projects
/
python_utils.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Add coding comments for files with utf8 characters in there.
[python_utils.git]
/
collect
/
shared_dict.py
diff --git
a/collect/shared_dict.py
b/collect/shared_dict.py
index 93aa452d50f9bb383d60662a3925ee51e438b644..e0a42f2c55c2fc865b0c89642d458ae26009c224 100644
(file)
--- a/
collect/shared_dict.py
+++ b/
collect/shared_dict.py
@@
-30,17
+30,8
@@
This class is based on https://github.com/luizalabs/shared-memory-dict
import pickle
from contextlib import contextmanager
from functools import wraps
import pickle
from contextlib import contextmanager
from functools import wraps
-from multiprocessing import shared_memory, Lock
-from typing import (
- Any,
- Dict,
- Generator,
- KeysView,
- ItemsView,
- Iterator,
- Optional,
- ValuesView,
-)
+from multiprocessing import RLock, shared_memory
+from typing import Any, Dict, Generator, ItemsView, Iterator, KeysView, Optional, ValuesView
from decorator_utils import synchronized
from decorator_utils import synchronized
@@
-64,30
+55,37
@@
class PickleSerializer:
class SharedDict(object):
NULL_BYTE = b'\x00'
class SharedDict(object):
NULL_BYTE = b'\x00'
- MPLOCK = Lock()
+ MPLOCK =
R
Lock()
def __init__(
self,
name: str,
def __init__(
self,
name: str,
- size
: int
,
+ size
_bytes: Optional[int] = None
,
) -> None:
super().__init__()
) -> None:
super().__init__()
+ self.name = name
self._serializer = PickleSerializer()
self._serializer = PickleSerializer()
- self.shared_memory = self._get_or_create_memory_block(name, size)
+ assert size_bytes is None or size_bytes > 0
+ self.shared_memory = self._get_or_create_memory_block(name, size_bytes)
self._ensure_memory_initialization()
self._ensure_memory_initialization()
+ self.lock = RLock()
+
+ def get_name(self):
+ return self.name
def _get_or_create_memory_block(
def _get_or_create_memory_block(
- self, name: str, size: int
+ self,
+ name: str,
+ size_bytes: Optional[int] = None,
) -> shared_memory.SharedMemory:
try:
return shared_memory.SharedMemory(name=name)
except FileNotFoundError:
) -> shared_memory.SharedMemory:
try:
return shared_memory.SharedMemory(name=name)
except FileNotFoundError:
- return shared_memory.SharedMemory(name=name, create=True, size=size)
+ assert size_bytes is not None
+ return shared_memory.SharedMemory(name=name, create=True, size=size_bytes)
def _ensure_memory_initialization(self):
def _ensure_memory_initialization(self):
- memory_is_empty = (
- bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
- )
+ memory_is_empty = bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b''
if memory_is_empty:
self.clear()
if memory_is_empty:
self.clear()
@@
-99,9
+97,9
@@
class SharedDict(object):
def cleanup(self) -> None:
if not hasattr(self, 'shared_memory'):
return
def cleanup(self) -> None:
if not hasattr(self, 'shared_memory'):
return
- self.shared_memory.unlink()
+ with SharedDict.MPLOCK:
+ self.shared_memory.unlink()
- @synchronized(MPLOCK)
def clear(self) -> None:
self._save_memory({})
def clear(self) -> None:
self._save_memory({})
@@
-109,12
+107,12
@@
class SharedDict(object):
with self._modify_db() as db:
return db.popitem()
with self._modify_db() as db:
return db.popitem()
- @synchronized(MPLOCK)
@contextmanager
def _modify_db(self) -> Generator:
@contextmanager
def _modify_db(self) -> Generator:
- db = self._read_memory()
- yield db
- self._save_memory(db)
+ with SharedDict.MPLOCK:
+ db = self._read_memory()
+ yield db
+ self._save_memory(db)
def __getitem__(self, key: str) -> Any:
return self._read_memory()[key]
def __getitem__(self, key: str) -> Any:
return self._read_memory()[key]
@@
-181,14
+179,16
@@
class SharedDict(object):
return db.setdefault(key, default)
def _save_memory(self, db: Dict[str, Any]) -> None:
return db.setdefault(key, default)
def _save_memory(self, db: Dict[str, Any]) -> None:
- data = self._serializer.dumps(db)
- try:
- self.shared_memory.buf[: len(data)] = data
- except ValueError as exc:
- raise ValueError("exceeds available storage") from exc
+ with SharedDict.MPLOCK:
+ data = self._serializer.dumps(db)
+ try:
+ self.shared_memory.buf[: len(data)] = data
+ except ValueError as exc:
+ raise ValueError("exceeds available storage") from exc
def _read_memory(self) -> Dict[str, Any]:
def _read_memory(self) -> Dict[str, Any]:
- return self._serializer.loads(self.shared_memory.buf.tobytes())
+ with SharedDict.MPLOCK:
+ return self._serializer.loads(self.shared_memory.buf.tobytes())
if __name__ == '__main__':
if __name__ == '__main__':