-def wait_all(futures: List[SmartFuture]) -> None:
- done_set = set()
- while len(done_set) < len(futures):
- for future in futures:
- i = future.get_id()
- if i not in done_set and future.wrapped_future.done():
- done_set.add(i)
- time.sleep(0.1)
+def wait_all(
+ futures: List[SmartFuture],
+ *,
+ log_exceptions: bool = True,
+) -> None:
+ real_futures = [x.wrapped_future for x in futures]
+ (done, not_done) = concurrent.futures.wait(
+ real_futures, timeout=None, return_when=concurrent.futures.ALL_COMPLETED
+ )
+ if log_exceptions:
+ for f in real_futures:
+ if not f.cancelled():
+ exception = f.exception()
+ if exception is not None:
+ logger.exception(exception)
+ traceback.print_tb(exception.__traceback__)
+ assert len(done) == len(real_futures)
+ assert len(not_done) == 0