Optionally surface exceptions that happen under executors by reading each completed future's exception and logging it.
[python_utils.git] / smart_future.py
index 8f23e7756b5417b15d89ba35c43dde971678ee14..c96c5a712d9f8d829eb73a21082f2ffbc0f41d48 100644 (file)
@@ -3,6 +3,8 @@
 from __future__ import annotations
 import concurrent
 import concurrent.futures as fut
+import logging
+import traceback
 from typing import Callable, List, TypeVar
 
 from overrides import overrides
@@ -12,10 +14,17 @@ from overrides import overrides
 from deferred_operand import DeferredOperand
 import id_generator
 
+logger = logging.getLogger(__name__)
+
 T = TypeVar('T')
 
 
-def wait_any(futures: List[SmartFuture], *, callback: Callable = None):
+def wait_any(
+    futures: List[SmartFuture],
+    *,
+    callback: Callable = None,
+    log_exceptions: bool = True,
+):
     real_futures = []
     smart_future_by_real_future = {}
     completed_futures = set()
@@ -28,17 +37,33 @@ def wait_any(futures: List[SmartFuture], *, callback: Callable = None):
             if callback is not None:
                 callback()
             completed_futures.add(f)
+            if log_exceptions and not f.cancelled():
+                exception = f.exception()
+                if exception is not None:
+                    logger.exception(exception)
+                    traceback.print_tb(exception.__traceback__)
             yield smart_future_by_real_future[f]
     if callback is not None:
         callback()
     return
 
 
-def wait_all(futures: List[SmartFuture]) -> None:
+def wait_all(
+    futures: List[SmartFuture],
+    *,
+    log_exceptions: bool = True,
+) -> None:
     real_futures = [x.wrapped_future for x in futures]
     (done, not_done) = concurrent.futures.wait(
         real_futures, timeout=None, return_when=concurrent.futures.ALL_COMPLETED
     )
+    if log_exceptions:
+        for f in real_futures:
+            if not f.cancelled():
+                exception = f.exception()
+                if exception is not None:
+                    logger.exception(exception)
+                    traceback.print_tb(exception.__traceback__)
     assert len(done) == len(real_futures)
     assert len(not_done) == 0