Optionally surface exceptions that happen under executors by reading
[python_utils.git] / smart_future.py
index 7dbec5004b4ba927331e71fb812fd482af678c3c..c96c5a712d9f8d829eb73a21082f2ffbc0f41d48 100644 (file)
@@ -1,49 +1,71 @@
 #!/usr/bin/env python3
 
 from __future__ import annotations
-from collections.abc import Mapping
+import concurrent
 import concurrent.futures as fut
-import time
+import logging
+import traceback
 from typing import Callable, List, TypeVar
 
+from overrides import overrides
+
 # This module is commonly used by others in here and should avoid
 # taking any unnecessary dependencies back on them.
 from deferred_operand import DeferredOperand
 import id_generator
 
+logger = logging.getLogger(__name__)
+
 T = TypeVar('T')
 
 
-def wait_any(futures: List[SmartFuture], *, callback: Callable = None):
-    finished: Mapping[int, bool] = {}
-    x = 0
-    while True:
-        future = futures[x]
-        if not finished.get(future.get_id(), False):
-            if future.is_ready():
-                finished[future.get_id()] = True
-                yield future
-            else:
-                if callback is not None:
-                    callback()
-                time.sleep(0.1)
-        x += 1
-        if x >= len(futures):
-            x = 0
-        if len(finished) == len(futures):
+def wait_any(
+    futures: List[SmartFuture],
+    *,
+    callback: Callable = None,
+    log_exceptions: bool = True,
+):
+    real_futures = []
+    smart_future_by_real_future = {}
+    completed_futures = set()
+    for _ in futures:
+        real_futures.append(_.wrapped_future)
+        smart_future_by_real_future[_.wrapped_future] = _
+    while len(completed_futures) != len(real_futures):
+        newly_completed_futures = concurrent.futures.as_completed(real_futures)
+        for f in newly_completed_futures:
             if callback is not None:
                 callback()
-            return
+            completed_futures.add(f)
+            if log_exceptions and not f.cancelled():
+                exception = f.exception()
+                if exception is not None:
+                    logger.exception(exception)
+                    traceback.print_tb(exception.__traceback__)
+            yield smart_future_by_real_future[f]
+    if callback is not None:
+        callback()
+    return
 
 
-def wait_all(futures: List[SmartFuture]) -> None:
-    done_set = set()
-    while len(done_set) < len(futures):
-        for future in futures:
-            i = future.get_id()
-            if i not in done_set and future.wrapped_future.done():
-                done_set.add(i)
-            time.sleep(0.1)
+def wait_all(
+    futures: List[SmartFuture],
+    *,
+    log_exceptions: bool = True,
+) -> None:
+    real_futures = [x.wrapped_future for x in futures]
+    (done, not_done) = concurrent.futures.wait(
+        real_futures, timeout=None, return_when=concurrent.futures.ALL_COMPLETED
+    )
+    if log_exceptions:
+        for f in real_futures:
+            if not f.cancelled():
+                exception = f.exception()
+                if exception is not None:
+                    logger.exception(exception)
+                    traceback.print_tb(exception.__traceback__)
+    assert len(done) == len(real_futures)
+    assert len(not_done) == 0
 
 
 class SmartFuture(DeferredOperand):
@@ -66,5 +88,6 @@ class SmartFuture(DeferredOperand):
 
     # You shouldn't have to call this; instead, have a look at defining a
     # method on DeferredOperand base class.
+    @overrides
     def _resolve(self, *, timeout=None) -> T:
         return self.wrapped_future.result(timeout)