Change locking boundaries for shared dict. Add a unit test.
[python_utils.git] / smart_future.py
1 #!/usr/bin/env python3
2
from __future__ import annotations
import concurrent
import concurrent.futures as fut
import logging
import traceback
from typing import Callable, List, Optional, TypeVar

from overrides import overrides

# This module is commonly used by others in here and should avoid
# taking any unnecessary dependencies back on them.
from deferred_operand import DeferredOperand
import id_generator
16
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Type variable used as the return annotation of SmartFuture._resolve.
T = TypeVar("T")
20
21
def wait_any(
    futures: List[SmartFuture],
    *,
    callback: Optional[Callable] = None,
    log_exceptions: bool = True,
):
    """Yield each SmartFuture in *futures* as its underlying future completes.

    This is a generator: nothing is waited on until it is iterated.  Entries
    are produced in completion order (as reported by
    ``concurrent.futures.as_completed``), not in input order.

    Args:
        futures: the SmartFutures to wait on.  Every entry is yielded exactly
            once, even when the same SmartFuture, or several SmartFutures
            wrapping the same underlying future, appear more than once.
        callback: optional zero-argument callable.  It is invoked each time an
            underlying future completes (before the matching SmartFuture(s)
            are yielded) and once more after all futures have completed.
        log_exceptions: when True, a non-cancelled future that finished by
            raising has its exception logged and then re-raised from this
            generator.

    Raises:
        The exception raised by a completed, non-cancelled future, when
        log_exceptions is True.
    """
    # Several input entries may share one underlying future; keep all of them
    # (in input order) so that every entry gets yielded.
    smart_futures_by_real_future = {}
    for smart_future in futures:
        smart_futures_by_real_future.setdefault(
            smart_future.wrapped_future, []
        ).append(smart_future)

    # as_completed() yields every distinct future exactly once, blocking as
    # needed, so one pass covers them all.  Looping until a completed-count
    # matched len(futures) could never terminate when duplicates were passed.
    for f in concurrent.futures.as_completed(smart_futures_by_real_future):
        if callback is not None:
            callback()
        if log_exceptions and not f.cancelled():
            exception = f.exception()
            if exception is not None:
                logger.warning(
                    f'Future {id(f)} raised an unhandled exception and exited.'
                )
                # logger.exception() only finds a traceback inside an except
                # block; hand the exception over explicitly instead.
                logger.error(exception, exc_info=exception)
                raise exception
        yield from smart_futures_by_real_future[f]
    if callback is not None:
        callback()
52
53
def wait_all(
    futures: List[SmartFuture],
    *,
    log_exceptions: bool = True,
) -> None:
    """Block until every SmartFuture in *futures* has completed.

    Args:
        futures: the SmartFutures to wait on; duplicate entries are allowed.
        log_exceptions: when True, the first non-cancelled future (in input
            order) that finished by raising has its exception logged and
            re-raised.

    Raises:
        The exception raised by a completed, non-cancelled future, when
        log_exceptions is True.
    """
    real_futures = [x.wrapped_future for x in futures]
    (done, not_done) = concurrent.futures.wait(
        real_futures, timeout=None, return_when=concurrent.futures.ALL_COMPLETED
    )
    if log_exceptions:
        for f in real_futures:
            if not f.cancelled():
                exception = f.exception()
                if exception is not None:
                    logger.warning(
                        f'Future {id(f)} raised an unhandled exception and exited.'
                    )
                    # logger.exception() only finds a traceback inside an
                    # except block; hand the exception over explicitly.
                    logger.error(exception, exc_info=exception)
                    raise exception
    # wait() returns *sets*, so compare with the number of distinct futures;
    # comparing with len(real_futures) would fail whenever one repeats.
    assert len(done) == len(set(real_futures))
    assert len(not_done) == 0
75
76
class SmartFuture(DeferredOperand):
    """Wrapper around a ``concurrent.futures.Future`` that can be used,
    mostly, like a normal (non-Future) value.

    Using a SmartFuture in an expression blocks until the result of the
    deferred operation is known.
    """

    def __init__(self, wrapped_future: fut.Future) -> None:
        # The underlying future whose result this object stands in for.
        self.wrapped_future = wrapped_future
        # Identifier taken from id_generator's "smart_future_id" sequence.
        new_id = id_generator.get("smart_future_id")
        self.id = new_id

    def get_id(self) -> int:
        """Return the identifier assigned when this SmartFuture was built."""
        identifier = self.id
        return identifier

    def is_ready(self) -> bool:
        """Return True once the wrapped future reports that it is done."""
        inner = self.wrapped_future
        return inner.done()

    # You shouldn't have to call this; instead, have a look at defining a
    # method on DeferredOperand base class.
    @overrides
    def _resolve(self, *, timeout=None) -> T:
        """Block (up to *timeout* seconds, or forever when None) and return
        the wrapped future's result, propagating whatever ``result()``
        raises."""
        inner = self.wrapped_future
        return inner.result(timeout=timeout)