More cleanup.
[python_utils.git] / smart_future.py
1 #!/usr/bin/env python3
2
3 """A future that can be treated like the result that it contains and
4 will not block until it is used.  At that point, if the underlying
5 value is not yet available, it will block until it becomes
6 available."""
7
8 from __future__ import annotations
9 import concurrent
10 import concurrent.futures as fut
11 import logging
12 from typing import Callable, List, Set, TypeVar
13
14 from overrides import overrides
15
16 import id_generator
17
18 # This module is commonly used by others in here and should avoid
19 # taking any unnecessary dependencies back on them.
20 from deferred_operand import DeferredOperand
21
22 logger = logging.getLogger(__name__)
23
24 T = TypeVar('T')
25
26
27 def wait_any(
28     futures: List[SmartFuture],
29     *,
30     callback: Callable = None,
31     log_exceptions: bool = True,
32 ):
33     real_futures = []
34     smart_future_by_real_future = {}
35     completed_futures: Set[fut.Future] = set()
36     for x in futures:
37         assert isinstance(x, SmartFuture)
38         real_futures.append(x.wrapped_future)
39         smart_future_by_real_future[x.wrapped_future] = x
40
41     while len(completed_futures) != len(real_futures):
42         newly_completed_futures = concurrent.futures.as_completed(real_futures)
43         for f in newly_completed_futures:
44             if callback is not None:
45                 callback()
46             completed_futures.add(f)
47             if log_exceptions and not f.cancelled():
48                 exception = f.exception()
49                 if exception is not None:
50                     logger.warning('Future 0x%x raised an unhandled exception and exited.', id(f))
51                     logger.exception(exception)
52                     raise exception
53             yield smart_future_by_real_future[f]
54     if callback is not None:
55         callback()
56
57
58 def wait_all(
59     futures: List[SmartFuture],
60     *,
61     log_exceptions: bool = True,
62 ) -> None:
63     real_futures = []
64     for x in futures:
65         assert isinstance(x, SmartFuture)
66         real_futures.append(x.wrapped_future)
67
68     (done, not_done) = concurrent.futures.wait(
69         real_futures, timeout=None, return_when=concurrent.futures.ALL_COMPLETED
70     )
71     if log_exceptions:
72         for f in real_futures:
73             if not f.cancelled():
74                 exception = f.exception()
75                 if exception is not None:
76                     logger.warning('Future 0x%x raised an unhandled exception and exited.', id(f))
77                     logger.exception(exception)
78                     raise exception
79     assert len(done) == len(real_futures)
80     assert len(not_done) == 0
81
82
83 class SmartFuture(DeferredOperand):
84     """This is a SmartFuture, a class that wraps a normal Future and can
85     then be used, mostly, like a normal (non-Future) identifier.
86
87     Using a FutureWrapper in expressions will block and wait until
88     the result of the deferred operation is known.
89     """
90
91     def __init__(self, wrapped_future: fut.Future) -> None:
92         assert isinstance(wrapped_future, fut.Future)
93         self.wrapped_future = wrapped_future
94         self.id = id_generator.get("smart_future_id")
95
96     def get_id(self) -> int:
97         return self.id
98
99     def is_ready(self) -> bool:
100         return self.wrapped_future.done()
101
102     # You shouldn't have to call this; instead, have a look at defining a
103     # method on DeferredOperand base class.
104     @overrides
105     def _resolve(self, timeout=None) -> T:
106         return self.wrapped_future.result(timeout)