X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=smart_future.py;h=460dcb95862a9a88ecafde45fa8886dbcbb9eaa1;hb=31c81f6539969a5eba864d3305f9fb7bf716a367;hp=f11be17bc0fabfeb4a111e4b3356bbcb1ed1633a;hpb=eb9e6df32ed696158bf34dba6464277b648f5c74;p=python_utils.git diff --git a/smart_future.py b/smart_future.py index f11be17..460dcb9 100644 --- a/smart_future.py +++ b/smart_future.py @@ -1,51 +1,85 @@ #!/usr/bin/env python3 from __future__ import annotations -from collections.abc import Mapping + +import concurrent import concurrent.futures as fut -import time -from typing import Callable, List, TypeVar +import logging +import traceback +from typing import Callable, List, Set, TypeVar from overrides import overrides +import id_generator + # This module is commonly used by others in here and should avoid # taking any unnecessary dependencies back on them. from deferred_operand import DeferredOperand -import id_generator + +logger = logging.getLogger(__name__) T = TypeVar('T') -def wait_any(futures: List[SmartFuture], *, callback: Callable = None): - finished: Mapping[int, bool] = {} - x = 0 - while True: - future = futures[x] - if not finished.get(future.get_id(), False): - if future.is_ready(): - finished[future.get_id()] = True - yield future - else: - if callback is not None: - callback() - time.sleep(0.1) - x += 1 - if x >= len(futures): - x = 0 - if len(finished) == len(futures): +def wait_any( + futures: List[SmartFuture], + *, + callback: Callable = None, + log_exceptions: bool = True, +): + real_futures = [] + smart_future_by_real_future = {} + completed_futures: Set[fut.Future] = set() + for x in futures: + assert type(x) == SmartFuture + real_futures.append(x.wrapped_future) + smart_future_by_real_future[x.wrapped_future] = x + + while len(completed_futures) != len(real_futures): + newly_completed_futures = concurrent.futures.as_completed(real_futures) + for f in newly_completed_futures: if callback is not None: callback() - return + completed_futures.add(f) + if log_exceptions and not f.cancelled(): + exception = f.exception() + if exception is not None: + logger.warning( + f'Future {id(f)} raised an unhandled exception and exited.' + ) + logger.exception(exception) + raise exception + yield smart_future_by_real_future[f] + if callback is not None: + callback() + return + +def wait_all( + futures: List[SmartFuture], + *, + log_exceptions: bool = True, +) -> None: + real_futures = [] + for x in futures: + assert type(x) == SmartFuture + real_futures.append(x.wrapped_future) -def wait_all(futures: List[SmartFuture]) -> None: - done_set = set() - while len(done_set) < len(futures): - for future in futures: - i = future.get_id() - if i not in done_set and future.wrapped_future.done(): - done_set.add(i) - time.sleep(0.1) + (done, not_done) = concurrent.futures.wait( + real_futures, timeout=None, return_when=concurrent.futures.ALL_COMPLETED + ) + if log_exceptions: + for f in real_futures: + if not f.cancelled(): + exception = f.exception() + if exception is not None: + logger.warning( + f'Future {id(f)} raised an unhandled exception and exited.' + ) + logger.exception(exception) + raise exception + assert len(done) == len(real_futures) + assert len(not_done) == 0 class SmartFuture(DeferredOperand): @@ -57,6 +91,7 @@ class SmartFuture(DeferredOperand): """ def __init__(self, wrapped_future: fut.Future) -> None: + assert type(wrapped_future) == fut.Future self.wrapped_future = wrapped_future self.id = id_generator.get("smart_future_id")