From f2b4fe83f6fc853a68653bd5e3d9fe0648c3d105 Mon Sep 17 00:00:00 2001 From: Scott Date: Sun, 30 Jan 2022 22:09:43 -0800 Subject: [PATCH] Fix a recent bug in executors. Thread executor needs to return its future. --- executors.py | 1 + smart_future.py | 15 +++++++++++---- tests/parallelize_itest.py | 3 ++- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/executors.py b/executors.py index 5b77a42..47b4a89 100644 --- a/executors.py +++ b/executors.py @@ -152,6 +152,7 @@ class ThreadExecutor(BaseExecutor): ) result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start)) result.add_done_callback(lambda _: self.adjust_task_count(-1)) + return result @overrides def shutdown(self, wait=True) -> None: diff --git a/smart_future.py b/smart_future.py index 2f3cbd9..604c149 100644 --- a/smart_future.py +++ b/smart_future.py @@ -28,9 +28,11 @@ def wait_any( real_futures = [] smart_future_by_real_future = {} completed_futures = set() - for _ in futures: - real_futures.append(_.wrapped_future) - smart_future_by_real_future[_.wrapped_future] = _ + for f in futures: + assert type(f) == SmartFuture + real_futures.append(f.wrapped_future) + smart_future_by_real_future[f.wrapped_future] = f + while len(completed_futures) != len(real_futures): newly_completed_futures = concurrent.futures.as_completed(real_futures) for f in newly_completed_futures: @@ -56,7 +58,11 @@ def wait_all( *, log_exceptions: bool = True, ) -> None: - real_futures = [x.wrapped_future for x in futures] + real_futures = [] + for f in futures: + assert type(f) == SmartFuture + real_futures.append(f.wrapped_future) + (done, not_done) = concurrent.futures.wait( real_futures, timeout=None, return_when=concurrent.futures.ALL_COMPLETED ) @@ -83,6 +89,7 @@ class SmartFuture(DeferredOperand): """ def __init__(self, wrapped_future: fut.Future) -> None: + assert type(wrapped_future) == fut.Future self.wrapped_future = wrapped_future self.id = id_generator.get("smart_future_id") diff --git a/tests/parallelize_itest.py b/tests/parallelize_itest.py index 11c5676..6ac9538 100755 --- a/tests/parallelize_itest.py +++ b/tests/parallelize_itest.py @@ -37,7 +37,8 @@ def compute_factorial_remote(n): def test_thread_parallelization() -> None: results = [] for _ in range(50): - results.append(compute_factorial_thread(_)) + f = compute_factorial_thread(_) + results.append(f) smart_future.wait_all(results) for future in results: print(f'Thread: {future}') -- 2.46.0