X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=smart_future.py;h=8f23e7756b5417b15d89ba35c43dde971678ee14;hb=36fea7f15ed17150691b5b3ead75450e575229ef;hp=f1ffee1c63250b4fb0d0be319a8090ce406f5fc0;hpb=497fb9e21f45ec08e1486abaee6dfa7b20b8a691;p=python_utils.git diff --git a/smart_future.py b/smart_future.py index f1ffee1..8f23e77 100644 --- a/smart_future.py +++ b/smart_future.py @@ -1,37 +1,46 @@ #!/usr/bin/env python3 from __future__ import annotations -from collections.abc import Mapping +import concurrent import concurrent.futures as fut -import time from typing import Callable, List, TypeVar +from overrides import overrides + +# This module is commonly used by others in here and should avoid +# taking any unnecessary dependencies back on them. from deferred_operand import DeferredOperand import id_generator T = TypeVar('T') -def wait_many(futures: List[SmartFuture], *, callback: Callable = None): - finished: Mapping[int, bool] = {} - x = 0 - while True: - future = futures[x] - if not finished.get(future.get_id(), False): - if future.is_ready(): - finished[future.get_id()] = True - yield future - else: - if callback is not None: - callback() - time.sleep(0.1) - x += 1 - if x >= len(futures): - x = 0 - if len(finished) == len(futures): +def wait_any(futures: List[SmartFuture], *, callback: Callable = None): + real_futures = [] + smart_future_by_real_future = {} + completed_futures = set() + for _ in futures: + real_futures.append(_.wrapped_future) + smart_future_by_real_future[_.wrapped_future] = _ + while len(completed_futures) != len(real_futures): + newly_completed_futures = concurrent.futures.as_completed(real_futures) + for f in newly_completed_futures: if callback is not None: callback() - return + completed_futures.add(f) + yield smart_future_by_real_future[f] + if callback is not None: + callback() + return + + +def wait_all(futures: List[SmartFuture]) -> None: + real_futures = [x.wrapped_future for x in futures] + (done, not_done) = concurrent.futures.wait( + real_futures, timeout=None, return_when=concurrent.futures.ALL_COMPLETED + ) + assert len(done) == len(real_futures) + assert len(not_done) == 0 class SmartFuture(DeferredOperand): @@ -54,5 +63,6 @@ class SmartFuture(DeferredOperand): # You shouldn't have to call this; instead, have a look at defining a # method on DeferredOperand base class. + @overrides def _resolve(self, *, timeout=None) -> T: return self.wrapped_future.result(timeout)