X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=smart_future.py;h=7aac8ebb68eb3bb3af6da50b127cfb344ff925e3;hb=532df2c5b57c7517dfb3dddd8c1358fbadf8baf3;hp=c097d53b7c849867db4dd7f81121097397d0bee3;hpb=ed8fa2b10b0177b15b7423263bdd390efde2f0c8;p=python_utils.git diff --git a/smart_future.py b/smart_future.py index c097d53..7aac8eb 100644 --- a/smart_future.py +++ b/smart_future.py @@ -1,46 +1,85 @@ #!/usr/bin/env python3 +# © Copyright 2021-2022, Scott Gasch + +"""A future that can be treated like the result that it contains and +will not block until it is used. At that point, if the underlying +value is not yet available, it will block until it becomes +available. + +""" + from __future__ import annotations import concurrent import concurrent.futures as fut -from typing import Callable, List, TypeVar +import logging +from typing import Callable, List, Set, TypeVar from overrides import overrides +import id_generator + # This module is commonly used by others in here and should avoid # taking any unnecessary dependencies back on them. from deferred_operand import DeferredOperand -import id_generator + +logger = logging.getLogger(__name__) T = TypeVar('T') -def wait_any(futures: List[SmartFuture], *, callback: Callable = None): +def wait_any( + futures: List[SmartFuture], + *, + callback: Callable = None, + log_exceptions: bool = True, +): real_futures = [] smart_future_by_real_future = {} - completed_futures = set() - for _ in futures: - real_futures.append(_.wrapped_future) - smart_future_by_real_future[_.wrapped_future] = _ + completed_futures: Set[fut.Future] = set() + for x in futures: + assert isinstance(x, SmartFuture) + real_futures.append(x.wrapped_future) + smart_future_by_real_future[x.wrapped_future] = x + while len(completed_futures) != len(real_futures): newly_completed_futures = concurrent.futures.as_completed(real_futures) for f in newly_completed_futures: if callback is not None: callback() completed_futures.add(f) + if log_exceptions and not f.cancelled(): + exception = f.exception() + if exception is not None: + logger.warning('Future 0x%x raised an unhandled exception and exited.', id(f)) + logger.exception(exception) + raise exception yield smart_future_by_real_future[f] if callback is not None: callback() - return -def wait_all(futures: List[SmartFuture]) -> None: - real_futures = [x.wrapped_future for x in futures] +def wait_all( + futures: List[SmartFuture], + *, + log_exceptions: bool = True, +) -> None: + real_futures = [] + for x in futures: + assert isinstance(x, SmartFuture) + real_futures.append(x.wrapped_future) + (done, not_done) = concurrent.futures.wait( - real_futures, - timeout=None, - return_when=concurrent.futures.ALL_COMPLETED + real_futures, timeout=None, return_when=concurrent.futures.ALL_COMPLETED ) + if log_exceptions: + for f in real_futures: + if not f.cancelled(): + exception = f.exception() + if exception is not None: + logger.warning('Future 0x%x raised an unhandled exception and exited.', id(f)) + logger.exception(exception) + raise exception assert len(done) == len(real_futures) assert len(not_done) == 0 @@ -54,6 +93,7 @@ class SmartFuture(DeferredOperand): """ def __init__(self, wrapped_future: fut.Future) -> None: + assert isinstance(wrapped_future, fut.Future) self.wrapped_future = wrapped_future self.id = id_generator.get("smart_future_id") @@ -66,5 +106,5 @@ class SmartFuture(DeferredOperand): # You shouldn't have to call this; instead, have a look at defining a # method on DeferredOperand base class. @overrides - def _resolve(self, *, timeout=None) -> T: + def _resolve(self, timeout=None) -> T: return self.wrapped_future.result(timeout)