Since this thing is on the innerwebs I suppose it should have a
[python_utils.git] / smart_future.py
index c097d53b7c849867db4dd7f81121097397d0bee3..7aac8ebb68eb3bb3af6da50b127cfb344ff925e3 100644 (file)
@@ -1,46 +1,85 @@
 #!/usr/bin/env python3
 
+# © Copyright 2021-2022, Scott Gasch
+
+"""A future that can be treated like the result that it contains and
+will not block until it is used.  At that point, if the underlying
+value is not yet available, it will block until it becomes
+available.
+
+"""
+
 from __future__ import annotations
 import concurrent
 import concurrent.futures as fut
-from typing import Callable, List, TypeVar
+import logging
+from typing import Callable, List, Set, TypeVar
 
 from overrides import overrides
 
+import id_generator
+
 # This module is commonly used by others in here and should avoid
 # taking any unnecessary dependencies back on them.
 from deferred_operand import DeferredOperand
-import id_generator
+
+logger = logging.getLogger(__name__)
 
 T = TypeVar('T')
 
 
-def wait_any(futures: List[SmartFuture], *, callback: Callable = None):
+def wait_any(
+    futures: List[SmartFuture],
+    *,
+    callback: Callable = None,
+    log_exceptions: bool = True,
+):
     real_futures = []
     smart_future_by_real_future = {}
-    completed_futures = set()
-    for _ in futures:
-        real_futures.append(_.wrapped_future)
-        smart_future_by_real_future[_.wrapped_future] = _
+    completed_futures: Set[fut.Future] = set()
+    for x in futures:
+        assert isinstance(x, SmartFuture)
+        real_futures.append(x.wrapped_future)
+        smart_future_by_real_future[x.wrapped_future] = x
+
     while len(completed_futures) != len(real_futures):
         newly_completed_futures = concurrent.futures.as_completed(real_futures)
         for f in newly_completed_futures:
             if callback is not None:
                 callback()
             completed_futures.add(f)
+            if log_exceptions and not f.cancelled():
+                exception = f.exception()
+                if exception is not None:
+                    logger.warning('Future 0x%x raised an unhandled exception and exited.', id(f))
+                    logger.exception(exception)
+                    raise exception
             yield smart_future_by_real_future[f]
     if callback is not None:
         callback()
-    return
 
 
-def wait_all(futures: List[SmartFuture]) -> None:
-    real_futures = [x.wrapped_future for x in futures]
+def wait_all(
+    futures: List[SmartFuture],
+    *,
+    log_exceptions: bool = True,
+) -> None:
+    real_futures = []
+    for x in futures:
+        assert isinstance(x, SmartFuture)
+        real_futures.append(x.wrapped_future)
+
     (done, not_done) = concurrent.futures.wait(
-        real_futures,
-        timeout=None,
-        return_when=concurrent.futures.ALL_COMPLETED
+        real_futures, timeout=None, return_when=concurrent.futures.ALL_COMPLETED
     )
+    if log_exceptions:
+        for f in real_futures:
+            if not f.cancelled():
+                exception = f.exception()
+                if exception is not None:
+                    logger.warning('Future 0x%x raised an unhandled exception and exited.', id(f))
+                    logger.exception(exception)
+                    raise exception
     assert len(done) == len(real_futures)
     assert len(not_done) == 0
 
@@ -54,6 +93,7 @@ class SmartFuture(DeferredOperand):
     """
 
     def __init__(self, wrapped_future: fut.Future) -> None:
+        assert isinstance(wrapped_future, fut.Future)
         self.wrapped_future = wrapped_future
         self.id = id_generator.get("smart_future_id")
 
@@ -66,5 +106,5 @@ class SmartFuture(DeferredOperand):
     # You shouldn't have to call this; instead, have a look at defining a
     # method on DeferredOperand base class.
     @overrides
-    def _resolve(self, *, timeout=None) -> T:
+    def _resolve(self, timeout=None) -> T:
         return self.wrapped_future.result(timeout)