#!/usr/bin/env python3
from __future__ import annotations
-from collections.abc import Mapping
+import concurrent
import concurrent.futures as fut
-import time
from typing import Callable, List, TypeVar
from overrides import overrides
def wait_any(futures: List[SmartFuture], *, callback: Callable = None):
- finished: Mapping[int, bool] = {}
- x = 0
- while True:
- future = futures[x]
- if not finished.get(future.get_id(), False):
- if future.is_ready():
- finished[future.get_id()] = True
- yield future
- else:
- if callback is not None:
- callback()
- time.sleep(0.1)
- x += 1
- if x >= len(futures):
- x = 0
- if len(finished) == len(futures):
+ real_futures = []
+ smart_future_by_real_future = {}
+ completed_futures = set()
+ for _ in futures:
+ real_futures.append(_.wrapped_future)
+ smart_future_by_real_future[_.wrapped_future] = _
+ while len(completed_futures) != len(real_futures):
+ newly_completed_futures = concurrent.futures.as_completed(real_futures)
+ for f in newly_completed_futures:
if callback is not None:
callback()
- return
+ completed_futures.add(f)
+ yield smart_future_by_real_future[f]
+ if callback is not None:
+ callback()
+ return
def wait_all(futures: List[SmartFuture]) -> None:
- done_set = set()
- while len(done_set) < len(futures):
- for future in futures:
- i = future.get_id()
- if i not in done_set and future.wrapped_future.done():
- done_set.add(i)
- time.sleep(0.1)
+ real_futures = [x.wrapped_future for x in futures]
+ (done, not_done) = concurrent.futures.wait(
+ real_futures,
+ timeout=None,
+ return_when=concurrent.futures.ALL_COMPLETED
+ )
+ assert len(done) == len(real_futures)
+ assert len(not_done) == 0
class SmartFuture(DeferredOperand):