0b73a45bdabbb663873c9ac883e045a642fd3cc3
[python_utils.git] / tests / shared_dict_test.py
1 #!/usr/bin/env python3
2
3 import unittest
4
5 import parallelize as p
6 import smart_future
7 import unittest_utils
8 from collect.shared_dict import SharedDict
9
10
class SharedDictTest(unittest.TestCase):
    """Tests for SharedDict shared between the test and parallelized workers.

    The ``@p.parallelize(method=p.Method.PROCESS)`` helpers reopen the dict
    by name and only ``close()`` their handle; the ``test_*`` methods create
    the dict with an explicit size, fan out to the helpers, verify the
    outcome, and are the only code that calls ``cleanup()``.
    """

    # Number of doit() invocations issued by test_basic_operations.
    NUM_SHARDS = 100
    # Number of add_one() invocations issued by test_locking_works.
    NUM_ADDERS = 10
    # Number of locked increments each add_one() invocation performs.
    INCREMENTS_PER_ADDER = 1000

    @p.parallelize(method=p.Method.PROCESS)
    def doit(self, n: int, dict_name: str):
        """Write shard ``n`` into the dict named ``dict_name`` and read it back.

        Dispatched via ``p.parallelize(method=p.Method.PROCESS)``. Returns
        ``n`` so the caller can confirm which shard this invocation handled.
        """
        d = SharedDict(dict_name)
        try:
            msg = f'Hello from shard {n}'
            d[n] = msg
            self.assertIn(n, d)
            self.assertEqual(msg, d[n])
            return n
        finally:
            # Only release this handle; the creating test owns cleanup().
            d.close()

    def test_basic_operations(self):
        """Every worker's write is visible to the creating test."""
        dict_name = 'test_shared_dict'
        d = SharedDict(dict_name, 4096)
        try:
            self.assertEqual(dict_name, d.get_name())
            name = d.get_name()
            results = [self.doit(n, name) for n in range(self.NUM_SHARDS)]
            smart_future.wait_all(results)
            for n, f in enumerate(results):
                self.assertTrue(f.wrapped_future.done())
                # A finished future may still hold a worker exception (e.g. a
                # failed assertion); result() surfaces it and checks the value.
                # NOTE(review): assumes wrapped_future follows the
                # concurrent.futures.Future API (done()/result()) — confirm.
                self.assertEqual(n, f.wrapped_future.result())
            # Require every expected shard; iterating over d alone would
            # pass vacuously if no worker managed to write.
            self.assertEqual(set(range(self.NUM_SHARDS)), set(d))
            for k in d:
                self.assertEqual(d[k], f'Hello from shard {k}')
        finally:
            # cleanup() must run even if close() raises, or the named
            # segment would be left behind for the next run.
            try:
                d.close()
            finally:
                d.cleanup()

    @p.parallelize(method=p.Method.PROCESS)
    def add_one(self, name: str):
        """Increment ``d["sum"]`` INCREMENTS_PER_ADDER times in the dict ``name``.

        Each ``+=`` is a read-modify-write performed while holding
        ``SharedDict.MPLOCK``; test_locking_works relies on that lock to
        avoid lost updates between concurrent invocations.
        """
        d = SharedDict(name)
        try:
            for _ in range(self.INCREMENTS_PER_ADDER):
                with SharedDict.MPLOCK:
                    d["sum"] += 1
        finally:
            d.close()

    def test_locking_works(self):
        """Locked increments from concurrent workers all land in the total."""
        dict_name = 'test_shared_dict_lock'
        d = SharedDict(dict_name, 4096)
        try:
            d["sum"] = 0
            name = d.get_name()
            results = [self.add_one(name) for _ in range(self.NUM_ADDERS)]
            smart_future.wait_all(results)
            self.assertEqual(
                self.NUM_ADDERS * self.INCREMENTS_PER_ADDER, d["sum"]
            )
        finally:
            # cleanup() must run even if close() raises.
            try:
                d.close()
            finally:
                d.cleanup()
66
67
if __name__ == '__main__':
    # Run directly (not imported by a test runner): unittest.main() loads
    # and runs the TestCase classes defined in this module.
    unittest.main()