9aa68f38a5396953636e257032542f9d07533225
[python_utils.git] / smart_future.py
1 #!/usr/bin/env python3
2
3 """A future that can be treated like the result that it contains and
4 will not block until it is used.  At that point, if the underlying
5 value is not yet available, it will block until it becomes
6 available."""
7
8 from __future__ import annotations
9
10 import concurrent
11 import concurrent.futures as fut
12 import logging
13 from typing import Callable, List, Set, TypeVar
14
15 from overrides import overrides
16
17 import id_generator
18
19 # This module is commonly used by others in here and should avoid
20 # taking any unnecessary dependencies back on them.
21 from deferred_operand import DeferredOperand
22
23 logger = logging.getLogger(__name__)
24
25 T = TypeVar('T')
26
27
28 def wait_any(
29     futures: List[SmartFuture],
30     *,
31     callback: Callable = None,
32     log_exceptions: bool = True,
33 ):
34     real_futures = []
35     smart_future_by_real_future = {}
36     completed_futures: Set[fut.Future] = set()
37     for x in futures:
38         assert isinstance(x, SmartFuture)
39         real_futures.append(x.wrapped_future)
40         smart_future_by_real_future[x.wrapped_future] = x
41
42     while len(completed_futures) != len(real_futures):
43         newly_completed_futures = concurrent.futures.as_completed(real_futures)
44         for f in newly_completed_futures:
45             if callback is not None:
46                 callback()
47             completed_futures.add(f)
48             if log_exceptions and not f.cancelled():
49                 exception = f.exception()
50                 if exception is not None:
51                     logger.warning('Future 0x%x raised an unhandled exception and exited.', id(f))
52                     logger.exception(exception)
53                     raise exception
54             yield smart_future_by_real_future[f]
55     if callback is not None:
56         callback()
57
58
59 def wait_all(
60     futures: List[SmartFuture],
61     *,
62     log_exceptions: bool = True,
63 ) -> None:
64     real_futures = []
65     for x in futures:
66         assert isinstance(x, SmartFuture)
67         real_futures.append(x.wrapped_future)
68
69     (done, not_done) = concurrent.futures.wait(
70         real_futures, timeout=None, return_when=concurrent.futures.ALL_COMPLETED
71     )
72     if log_exceptions:
73         for f in real_futures:
74             if not f.cancelled():
75                 exception = f.exception()
76                 if exception is not None:
77                     logger.warning('Future 0x%x raised an unhandled exception and exited.', id(f))
78                     logger.exception(exception)
79                     raise exception
80     assert len(done) == len(real_futures)
81     assert len(not_done) == 0
82
83
84 class SmartFuture(DeferredOperand):
85     """This is a SmartFuture, a class that wraps a normal Future and can
86     then be used, mostly, like a normal (non-Future) identifier.
87
88     Using a FutureWrapper in expressions will block and wait until
89     the result of the deferred operation is known.
90     """
91
92     def __init__(self, wrapped_future: fut.Future) -> None:
93         assert isinstance(wrapped_future, fut.Future)
94         self.wrapped_future = wrapped_future
95         self.id = id_generator.get("smart_future_id")
96
97     def get_id(self) -> int:
98         return self.id
99
100     def is_ready(self) -> bool:
101         return self.wrapped_future.done()
102
103     # You shouldn't have to call this; instead, have a look at defining a
104     # method on DeferredOperand base class.
105     @overrides
106     def _resolve(self, timeout=None) -> T:
107         return self.wrapped_future.result(timeout)