from __future__ import annotations
import concurrent
import concurrent.futures as fut
+import logging
+import traceback
from typing import Callable, List, TypeVar
from overrides import overrides
from deferred_operand import DeferredOperand
import id_generator
+logger = logging.getLogger(__name__)
+
T = TypeVar('T')
-def wait_any(futures: List[SmartFuture], *, callback: Callable = None):
+def wait_any(
+ futures: List[SmartFuture],
+ *,
+ callback: Callable = None,
+ log_exceptions: bool = True,
+):
real_futures = []
smart_future_by_real_future = {}
completed_futures = set()
if callback is not None:
callback()
completed_futures.add(f)
+ if log_exceptions and not f.cancelled():
+ exception = f.exception()
+ if exception is not None:
+ logger.exception(exception)
+ traceback.print_tb(exception.__traceback__)
yield smart_future_by_real_future[f]
if callback is not None:
callback()
return
-def wait_all(futures: List[SmartFuture]) -> None:
+def wait_all(
+ futures: List[SmartFuture],
+ *,
+ log_exceptions: bool = True,
+) -> None:
real_futures = [x.wrapped_future for x in futures]
(done, not_done) = concurrent.futures.wait(
real_futures, timeout=None, return_when=concurrent.futures.ALL_COMPLETED
)
+ if log_exceptions:
+ for f in real_futures:
+ if not f.cancelled():
+ exception = f.exception()
+ if exception is not None:
+ logger.exception(exception)
+ traceback.print_tb(exception.__traceback__)
assert len(done) == len(real_futures)
assert len(not_done) == 0