X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=smart_future.py;h=1f6e6f0aedcf05966e536ec8f10f570c2175a3e4;hb=a4bf4d05230474ad14243d67ac7f8c938f670e58;hp=c96c5a712d9f8d829eb73a21082f2ffbc0f41d48;hpb=822454f580c1ff9eb207b8da46cdfae24e30cde1;p=python_utils.git diff --git a/smart_future.py b/smart_future.py index c96c5a7..1f6e6f0 100644 --- a/smart_future.py +++ b/smart_future.py @@ -5,7 +5,7 @@ import concurrent import concurrent.futures as fut import logging import traceback -from typing import Callable, List, TypeVar +from typing import Callable, List, Set, TypeVar from overrides import overrides @@ -27,10 +27,12 @@ def wait_any( ): real_futures = [] smart_future_by_real_future = {} - completed_futures = set() - for _ in futures: - real_futures.append(_.wrapped_future) - smart_future_by_real_future[_.wrapped_future] = _ + completed_futures: Set[fut.Future] = set() + for x in futures: + assert type(x) == SmartFuture + real_futures.append(x.wrapped_future) + smart_future_by_real_future[x.wrapped_future] = x + while len(completed_futures) != len(real_futures): newly_completed_futures = concurrent.futures.as_completed(real_futures) for f in newly_completed_futures: @@ -40,8 +42,11 @@ def wait_any( if log_exceptions and not f.cancelled(): exception = f.exception() if exception is not None: + logger.warning( + f'Future {id(f)} raised an unhandled exception and exited.' + ) logger.exception(exception) - traceback.print_tb(exception.__traceback__) + raise exception yield smart_future_by_real_future[f] if callback is not None: callback() @@ -53,7 +58,11 @@ def wait_all( *, log_exceptions: bool = True, ) -> None: - real_futures = [x.wrapped_future for x in futures] + real_futures = [] + for x in futures: + assert type(x) == SmartFuture + real_futures.append(x.wrapped_future) + (done, not_done) = concurrent.futures.wait( real_futures, timeout=None, return_when=concurrent.futures.ALL_COMPLETED ) @@ -62,8 +71,11 @@ def wait_all( if not f.cancelled(): exception = f.exception() if exception is not None: + logger.warning( + f'Future {id(f)} raised an unhandled exception and exited.' + ) logger.exception(exception) - traceback.print_tb(exception.__traceback__) + raise exception assert len(done) == len(real_futures) assert len(not_done) == 0 @@ -77,6 +89,7 @@ class SmartFuture(DeferredOperand): """ def __init__(self, wrapped_future: fut.Future) -> None: + assert type(wrapped_future) == fut.Future self.wrapped_future = wrapped_future self.id = id_generator.get("smart_future_id")