Used isort to sort imports. Also added isort to the git pre-commit hook.
[python_utils.git] / smart_future.py
1 #!/usr/bin/env python3
2
3 from __future__ import annotations
4
5 import concurrent
6 import concurrent.futures as fut
7 import logging
8 import traceback
9 from typing import Callable, List, Set, TypeVar
10
11 from overrides import overrides
12
13 import id_generator
14
15 # This module is commonly used by others in here and should avoid
16 # taking any unnecessary dependencies back on them.
17 from deferred_operand import DeferredOperand
18
logger = logging.getLogger(__name__)

# Type variable used as the return annotation of SmartFuture._resolve below.
T = TypeVar('T')
22
23
def wait_any(
    futures: List[SmartFuture],
    *,
    callback: Callable = None,
    log_exceptions: bool = True,
):
    """Yield each SmartFuture in *futures* as soon as it completes.

    This is a generator: it blocks until one of the wrapped futures has
    finished, yields the SmartFuture(s) wrapping it, and repeats until
    every future has finished.

    Args:
        futures: the SmartFutures to wait on; each must be a SmartFuture
            (or subclass).  If several entries wrap the same underlying
            future, each entry is yielded when that future completes.
        callback: optional zero-argument callable, invoked each time a
            future completes (just before it is yielded) and once more
            after all futures have completed.
        log_exceptions: if True, a completed, non-cancelled future whose
            work raised is logged and its exception is re-raised from this
            generator.  If False, such futures are yielded like any other.

    Raises:
        Whatever exception a wrapped future raised, when log_exceptions
        is True.
    """
    # Map each underlying future to every SmartFuture that wraps it, so
    # duplicate entries are neither lost nor allowed to stall the wait.
    smart_futures_by_real_future = {}
    for x in futures:
        assert isinstance(x, SmartFuture)
        smart_futures_by_real_future.setdefault(x.wrapped_future, []).append(x)

    # as_completed() yields every distinct future exactly once and, with no
    # timeout, blocks until all of them have finished, so one pass suffices.
    # (Counting completions against len(futures) would never terminate if
    # the same future were passed more than once.)
    for f in concurrent.futures.as_completed(smart_futures_by_real_future.keys()):
        if callback is not None:
            callback()
        if log_exceptions and not f.cancelled():
            exception = f.exception()
            if exception is not None:
                logger.warning(
                    f'Future {id(f)} raised an unhandled exception and exited.'
                )
                # logger.exception() only captures a traceback when called
                # inside an except block; hand the exception over explicitly.
                logger.error(exception, exc_info=exception)
                raise exception
        yield from smart_futures_by_real_future[f]
    if callback is not None:
        callback()
56
57
def wait_all(
    futures: List[SmartFuture],
    *,
    log_exceptions: bool = True,
) -> None:
    """Block until every SmartFuture in *futures* has completed.

    Args:
        futures: the SmartFutures to wait on; each must be a SmartFuture
            (or subclass).  The same underlying future may appear more
            than once.
        log_exceptions: if True, after everything has finished, the first
            non-cancelled future (in input order) whose work raised is
            logged and its exception is re-raised.  If False, no future's
            exception is inspected here.

    Raises:
        Whatever exception a wrapped future raised, when log_exceptions
        is True.
    """
    real_futures = []
    for x in futures:
        assert isinstance(x, SmartFuture)
        real_futures.append(x.wrapped_future)

    (done, not_done) = concurrent.futures.wait(
        real_futures, timeout=None, return_when=concurrent.futures.ALL_COMPLETED
    )
    if log_exceptions:
        for f in real_futures:
            if not f.cancelled():
                exception = f.exception()
                if exception is not None:
                    logger.warning(
                        f'Future {id(f)} raised an unhandled exception and exited.'
                    )
                    # logger.exception() only captures a traceback when called
                    # inside an except block; hand the exception over explicitly.
                    logger.error(exception, exc_info=exception)
                    raise exception
    # wait() returns sets, so compare against the number of *distinct*
    # futures; a future listed twice must not trip this sanity check.
    assert len(done) == len(set(real_futures))
    assert len(not_done) == 0
83
84
class SmartFuture(DeferredOperand):
    """This is a SmartFuture, a class that wraps a normal Future and can
    then be used, mostly, like a normal (non-Future) identifier.

    Using a SmartFuture in expressions will block and wait until
    the result of the deferred operation is known.
    """

    def __init__(self, wrapped_future: fut.Future) -> None:
        """Wrap *wrapped_future*, which must be a concurrent.futures.Future
        (or subclass), and assign this wrapper an id from id_generator."""
        assert isinstance(wrapped_future, fut.Future)
        self.wrapped_future = wrapped_future
        self.id = id_generator.get("smart_future_id")

    def get_id(self) -> int:
        """Return the id obtained from id_generator at construction time."""
        return self.id

    def is_ready(self) -> bool:
        """Return True once the wrapped future has finished running
        (successfully or with an exception) or was cancelled."""
        return self.wrapped_future.done()

    # You shouldn't have to call this; instead, have a look at defining a
    # method on DeferredOperand base class.
    @overrides
    def _resolve(self, *, timeout=None) -> T:
        """Block until the wrapped future finishes and return its result.

        Re-raises the future's exception if its work raised; raises
        concurrent.futures.TimeoutError if *timeout* seconds elapse first
        (None means wait forever).
        """
        return self.wrapped_future.result(timeout)