230bdb989c4c5e74c6d31cecebee3ea468b277c8
[python_utils.git] / tests / shared_dict_test.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """shared_dict unittest."""
6
7 import unittest
8
9 import parallelize as p
10 import smart_future
11 import unittest_utils
12 from collect.shared_dict import SharedDict
13
14
class SharedDictTest(unittest.TestCase):
    """Exercise SharedDict from worker processes spawned via parallelize."""

    @p.parallelize(method=p.Method.PROCESS)
    def doit(self, n: int, dict_name: str):
        """Worker: attach to the named SharedDict, write key ``n``, verify it.

        Runs in a separate process (``p.Method.PROCESS``). Assertion failures
        here are raised inside the worker, so the caller must inspect the
        returned future's result to notice them.

        Args:
            n: key to write; also the value returned on success.
            dict_name: name of an already-created SharedDict to attach to.

        Returns:
            ``n`` (delivered to the caller through the future).
        """
        d = SharedDict(dict_name)
        try:
            msg = f'Hello from shard {n}'
            d[n] = msg
            self.assertIn(n, d)
            self.assertEqual(msg, d[n])
            return n
        finally:
            # Release only this process's handle; the creating test calls
            # cleanup() once every worker is finished.
            d.close()

    def test_basic_operations(self):
        """Write 100 keys from separate processes and check that all arrive."""
        dict_name = 'test_shared_dict'
        num_shards = 100
        d = SharedDict(dict_name, 4096)
        try:
            self.assertEqual(dict_name, d.get_name())
            results = [self.doit(n, d.get_name()) for n in range(num_shards)]
            smart_future.wait_all(results)
            for n, f in enumerate(results):
                self.assertTrue(f.wrapped_future.done())
                # done() is also true for a worker that raised; result()
                # re-raises that worker's exception (e.g. a failed assert).
                self.assertEqual(n, f.wrapped_future.result())
            # Check every expected key rather than just iterating d, which
            # would pass vacuously if some (or all) writes were missing.
            for n in range(num_shards):
                self.assertIn(n, d)
                self.assertEqual(f'Hello from shard {n}', d[n])
        finally:
            d.close()
            d.cleanup()

    @p.parallelize(method=p.Method.PROCESS)
    def add_one(self, name: str, iterations: int = 1000):
        """Worker: increment ``d["sum"]`` ``iterations`` times under MPLOCK.

        Args:
            name: name of an already-created SharedDict holding a "sum" key.
            iterations: how many increments to perform (default 1000).
        """
        d = SharedDict(name)
        try:
            for _ in range(iterations):
                # ``+=`` is a read-modify-write; holding the shared lock is
                # what should prevent concurrent workers losing updates.
                with SharedDict.MPLOCK:
                    d["sum"] += 1
        finally:
            d.close()

    def test_locking_works(self):
        """Several processes increment one counter; no update may be lost."""
        dict_name = 'test_shared_dict_lock'
        num_workers = 10
        iterations = 1000
        d = SharedDict(dict_name, 4096)
        try:
            d["sum"] = 0
            results = [
                self.add_one(d.get_name(), iterations)
                for _ in range(num_workers)
            ]
            smart_future.wait_all(results)
            for f in results:
                # Re-raise any exception from a worker so the real cause is
                # reported instead of only a mismatched sum.
                f.wrapped_future.result()
            self.assertEqual(num_workers * iterations, d["sum"])
        finally:
            d.close()
            d.cleanup()
70
71
if __name__ == "__main__":
    # Script entry point: hand every TestCase in this module to unittest.
    unittest.main()