Tweak histograms in executors to include seconds label.
[python_utils.git] / smart_future.py
index c96c5a712d9f8d829eb73a21082f2ffbc0f41d48..604c149520464bcd9d8c5a55cf8905acd5ec34d4 100644 (file)
@@ -28,9 +28,11 @@ def wait_any(
     real_futures = []
     smart_future_by_real_future = {}
     completed_futures = set()
-    for _ in futures:
-        real_futures.append(_.wrapped_future)
-        smart_future_by_real_future[_.wrapped_future] = _
+    for f in futures:
+        assert type(f) == SmartFuture
+        real_futures.append(f.wrapped_future)
+        smart_future_by_real_future[f.wrapped_future] = f
+
     while len(completed_futures) != len(real_futures):
         newly_completed_futures = concurrent.futures.as_completed(real_futures)
         for f in newly_completed_futures:
@@ -40,8 +42,11 @@ def wait_any(
             if log_exceptions and not f.cancelled():
                 exception = f.exception()
                 if exception is not None:
+                    logger.warning(
+                        f'Future {id(f)} raised an unhandled exception and exited.'
+                    )
                     logger.exception(exception)
-                    traceback.print_tb(exception.__traceback__)
+                    raise exception
             yield smart_future_by_real_future[f]
     if callback is not None:
         callback()
@@ -53,7 +58,11 @@ def wait_all(
     *,
     log_exceptions: bool = True,
 ) -> None:
-    real_futures = [x.wrapped_future for x in futures]
+    real_futures = []
+    for f in futures:
+        assert type(f) == SmartFuture
+        real_futures.append(f.wrapped_future)
+
     (done, not_done) = concurrent.futures.wait(
         real_futures, timeout=None, return_when=concurrent.futures.ALL_COMPLETED
     )
@@ -62,8 +71,11 @@ def wait_all(
             if not f.cancelled():
                 exception = f.exception()
                 if exception is not None:
+                    logger.warning(
+                        f'Future {id(f)} raised an unhandled exception and exited.'
+                    )
                     logger.exception(exception)
-                    traceback.print_tb(exception.__traceback__)
+                    raise exception
     assert len(done) == len(real_futures)
     assert len(not_done) == 0
 
@@ -77,6 +89,7 @@ class SmartFuture(DeferredOperand):
     """
 
     def __init__(self, wrapped_future: fut.Future) -> None:
+        assert type(wrapped_future) == fut.Future
         self.wrapped_future = wrapped_future
         self.id = id_generator.get("smart_future_id")