Optionally surface exceptions that happen under executors by reading the returned Future.
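
With submit() handing back the underlying concurrent.futures.Future, a
caller can, if it wants to, surface an exception raised inside a worker
by reading the future's result.  A minimal sketch, not part of this
diff (the failing callable and its argument are made up for
illustration; ThreadExecutor is the class changed below):

    def might_blow_up(x: int) -> int:
        if x < 0:
            raise ValueError('negative input')
        return x * 2

    executor = ThreadExecutor(max_workers=2)
    f = executor.submit(might_blow_up, -1)
    try:
        f.result()    # re-raises the worker's ValueError in the caller
    except ValueError as e:
        print(f'caught exception from worker: {e}')
    finally:
        executor.shutdown(wait=True)
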
diff --git a/executors.py b/executors.py
index 3786954a418c257bd2c7d0ae81d1455c0d8ea1cc..1df1877582362bdc9971d64c94df44f6d3e943cd 100644
--- a/executors.py
+++ b/executors.py
@@ -74,10 +74,10 @@ def make_cloud_pickle(fun, *args, **kwargs):
 class BaseExecutor(ABC):
     def __init__(self, *, title=''):
         self.title = title
-        self.task_count = 0
         self.histogram = hist.SimpleHistogram(
             hist.SimpleHistogram.n_evenly_spaced_buckets(int(0), int(500), 50)
         )
+        self.task_count = 0
 
     @abstractmethod
     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
@@ -88,9 +88,23 @@ class BaseExecutor(ABC):
         pass
 
     def adjust_task_count(self, delta: int) -> None:
+        """Change the task count.  Note: do not call this method from a
+        worker, it should only be called by the launcher process /
+        thread / machine.
+
+        """
         self.task_count += delta
+        print(f'(adjusted task count by {delta} to {self.task_count})')
         logger.debug(f'Executor current task count is {self.task_count}')
 
+    def get_task_count(self) -> int:
+        """Change the task count.  Note: do not call this method from a
+        worker, it should only be called by the launcher process /
+        thread / machine.
+
+        """
+        return self.task_count
+
 
 class ThreadExecutor(BaseExecutor):
     def __init__(self, max_workers: Optional[int] = None):
@@ -105,15 +119,10 @@ class ThreadExecutor(BaseExecutor):
             max_workers=workers, thread_name_prefix="thread_executor_helper"
         )
 
+    # This is run on a different thread; do not adjust task count here.
     def run_local_bundle(self, fun, *args, **kwargs):
         logger.debug(f"Running local bundle at {fun.__name__}")
-        start = time.time()
         result = fun(*args, **kwargs)
-        end = time.time()
-        self.adjust_task_count(-1)
-        duration = end - start
-        logger.debug(f"{fun.__name__} finished; used {duration:.1f}s")
-        self.histogram.add_item(duration)
         return result
 
     @overrides
@@ -123,9 +132,13 @@ class ThreadExecutor(BaseExecutor):
         newargs.append(function)
         for arg in args:
             newargs.append(arg)
-        return self._thread_pool_executor.submit(
+        start = time.time()
+        result = self._thread_pool_executor.submit(
             self.run_local_bundle, *newargs, **kwargs
         )
+        result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
+        result.add_done_callback(lambda _: self.adjust_task_count(-1))
+        return result
 
     @overrides
     def shutdown(self, wait=True) -> None:
@@ -147,11 +159,11 @@ class ProcessExecutor(BaseExecutor):
             max_workers=workers,
         )
 
+    # This is run in another process; do not adjust task count here.
     def run_cloud_pickle(self, pickle):
         fun, args, kwargs = cloudpickle.loads(pickle)
         logger.debug(f"Running pickled bundle at {fun.__name__}")
         result = fun(*args, **kwargs)
-        self.adjust_task_count(-1)
         return result
 
     @overrides
@@ -161,6 +173,7 @@ class ProcessExecutor(BaseExecutor):
         pickle = make_cloud_pickle(function, *args, **kwargs)
         result = self._process_executor.submit(self.run_cloud_pickle, pickle)
         result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
+        result.add_done_callback(lambda _: self.adjust_task_count(-1))
         return result
 
     @overrides
@@ -428,21 +441,29 @@ class WeightedRandomRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
     def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
         grabbag = []
         for worker in self.workers:
-            for x in range(0, worker.count):
-                for y in range(0, worker.weight):
-                    grabbag.append(worker)
-
-        for _ in range(0, 5):
-            random.shuffle(grabbag)
-            worker = grabbag[0]
-            if worker.machine != machine_to_avoid or _ > 2:
+            if worker.machine != machine_to_avoid:
                 if worker.count > 0:
-                    worker.count -= 1
-                    logger.debug(f'Selected worker {worker}')
-                    return worker
-        msg = 'Unexpectedly could not find a worker, retrying...'
-        logger.warning(msg)
-        return None
+                    for _ in range(worker.count * worker.weight):
+                        grabbag.append(worker)
+
+        if len(grabbag) == 0:
+            logger.debug(
+                f'There are no available workers that avoid {machine_to_avoid}...'
+            )
+            for worker in self.workers:
+                if worker.count > 0:
+                    for _ in range(worker.count * worker.weight):
+                        grabbag.append(worker)
+
+        if len(grabbag) == 0:
+            logger.warning('There are no available workers?!')
+            return None
+
+        worker = random.sample(grabbag, 1)[0]
+        assert worker.count > 0
+        worker.count -= 1
+        logger.debug(f'Chose worker {worker}')
+        return worker
 
 
 class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
@@ -482,7 +503,9 @@ class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
 
 class RemoteExecutor(BaseExecutor):
     def __init__(
-        self, workers: List[RemoteWorkerRecord], policy: RemoteWorkerSelectionPolicy
+        self,
+        workers: List[RemoteWorkerRecord],
+        policy: RemoteWorkerSelectionPolicy,
     ) -> None:
         super().__init__()
         self.workers = workers
@@ -513,7 +536,7 @@ class RemoteExecutor(BaseExecutor):
         ) = self.run_periodic_heartbeat()
 
     @background_thread
-    def run_periodic_heartbeat(self, stop_event) -> None:
+    def run_periodic_heartbeat(self, stop_event: threading.Event) -> None:
         while not stop_event.is_set():
             time.sleep(5.0)
             logger.debug('Running periodic heartbeat code...')
@@ -521,8 +544,10 @@ class RemoteExecutor(BaseExecutor):
         logger.debug('Periodic heartbeat thread shutting down.')
 
     def heartbeat(self) -> None:
+        # Note: this is invoked on a background thread, not an
+        # executor thread.  Be careful what you do here because it
+        # must return promptly so status keeps getting dumped periodically.
         with self.status.lock:
-            # Dump regular progress report
             self.status.periodic_dump(self.total_bundles_submitted)
 
             # Look for bundles to reschedule via executor.submit
@@ -537,7 +562,7 @@ class RemoteExecutor(BaseExecutor):
         if (
             num_done > 2
             and num_idle_workers > 1
-            and (self.last_backup is None or (now - self.last_backup > 6.0))
+            and (self.last_backup is None or (now - self.last_backup > 9.0))
             and self.backup_lock.acquire(blocking=False)
         ):
             try:
@@ -1102,7 +1127,7 @@ class DefaultExecutors(object):
                     RemoteWorkerRecord(
                         username='scott',
                         machine='cheetah.house',
-                        weight=34,
+                        weight=30,
                         count=6,
                     ),
                 )
@@ -1132,7 +1157,7 @@ class DefaultExecutors(object):
                     RemoteWorkerRecord(
                         username='scott',
                         machine='puma.cabin',
-                        weight=25,
+                        weight=30,
                         count=6,
                     ),
                 )
@@ -1142,7 +1167,7 @@ class DefaultExecutors(object):
                     RemoteWorkerRecord(
                         username='scott',
                         machine='backup.house',
-                        weight=7,
+                        weight=8,
                         count=2,
                     ),
                 )
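
The rewritten WeightedRandomRemoteWorkerSelectionPolicy.acquire_worker()
builds a "grabbag" holding count * weight copies of each eligible worker
and samples one uniformly, so a worker is chosen with probability
proportional to count * weight, preferring machines other than
machine_to_avoid.  A standalone sketch of that selection logic (the
Worker record here is a hypothetical stand-in for RemoteWorkerRecord;
machine names and weights are taken from the diff above):

    import random
    from dataclasses import dataclass

    @dataclass
    class Worker:
        machine: str
        weight: int
        count: int

    def pick(workers, machine_to_avoid=None):
        # Each eligible worker appears count * weight times in the grabbag.
        grabbag = [
            w
            for w in workers
            if w.count > 0 and w.machine != machine_to_avoid
            for _ in range(w.count * w.weight)
        ]
        if not grabbag:
            # Nothing avoids machine_to_avoid; fall back to any worker with capacity.
            grabbag = [
                w
                for w in workers
                if w.count > 0
                for _ in range(w.count * w.weight)
            ]
        if not grabbag:
            return None
        chosen = random.sample(grabbag, 1)[0]
        chosen.count -= 1   # reserve a slot on the chosen worker
        return chosen

    workers = [Worker('cheetah.house', 30, 6), Worker('backup.house', 8, 2)]
    print(pick(workers, machine_to_avoid='backup.house').machine)  # cheetah.house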