5 from collect.shared_dict import SharedDict
6 import parallelize as p
# Unit tests for collect.shared_dict.SharedDict: basic cross-process reads and
# writes, and serialization of updates through SharedDict.MPLOCK.
# NOTE(review): this is a numbered excerpt. Each line carries its original
# line number and several lines (loops, result collection, try/finally,
# cleanup) are elided, so the comments below describe only what the visible
# lines show; anything else is marked as an assumption.
11 class SharedDictTest(unittest.TestCase):
# Worker, dispatched with p.Method.PROCESS. Callers get back an object that
# exposes .wrapped_future (see test_basic_operations), presumably a future
# wrapper. It re-opens the shared dict by name with no size argument. This
# presumably attaches to the segment the parent created; confirm. It then
# checks that key `n` is present and maps to this shard's message.
12 @p.parallelize(method=p.Method.PROCESS)
13 def doit(self, n: int, dict_name: str):
14 d = SharedDict(dict_name)
16 msg = f'Hello from shard {n}'
# NOTE(review): msg is presumably stored under key n on an elided line (17);
# verify. Also consider assertIn(n, d) for a clearer failure message.
18 self.assertTrue(n in d)
19 self.assertEqual(msg, d[n])
# Creates the named shared dict in this process and fans doit() out across
# shard numbers. It waits for every worker, then checks that each key k
# holds 'Hello from shard {k}'.
24 def test_basic_operations(self):
25 dict_name = 'test_shared_dict'
# The second argument is presumably the segment size in bytes; confirm
# against SharedDict's constructor.
26 d = SharedDict(dict_name, 4096)
28 self.assertEqual(dict_name, d.get_name())
# Called once per shard number `n`. The loop header and the `results` list
# that collects each returned handle are on elided lines.
31 f = self.doit(n, d.get_name())
# Block until every dispatched worker has finished.
# NOTE(review): `smart_future` is not among the visible imports; it is
# presumably imported on an elided line.
33 smart_future.wait_all(results)
# Presumably runs inside a loop over `results`: each worker's underlying
# future must be complete.
35 self.assertTrue(f.wrapped_future.done())
# Presumably runs inside a loop over the dict's keys: each key must hold its
# shard's message.
37 self.assertEqual(d[k], f'Hello from shard {k}')
# Worker, dispatched with p.Method.PROCESS. It updates the shared dict while
# holding the class-level SharedDict.MPLOCK. MPLOCK is presumably a
# cross-process lock, so concurrent workers' updates are serialized.
# NOTE(review): reopening the dict by `name` and the update itself
# (presumably incrementing d["sum"]; see test_locking_works) are on elided
# lines.
42 @p.parallelize(method=p.Method.PROCESS)
43 def add_one(self, name: str):
47 with SharedDict.MPLOCK:
# Checks that updates made under SharedDict.MPLOCK from several processes
# are not lost: after all add_one() workers finish, d["sum"] must be 10000.
# NOTE(review): the initial value of d["sum"], the worker count, the
# per-worker increment count and the `results` list are on elided lines.
# 10000 presumably equals workers x increments per worker; confirm.
52 def test_locking_works(self):
53 dict_name = 'test_shared_dict_lock'
54 d = SharedDict(dict_name, 4096)
59 f = self.add_one(d.get_name())
61 smart_future.wait_all(results)
62 self.assertEqual(10000, d["sum"])
68 if __name__ == '__main__':