Fix a recent bug in executors. Thread executor needs to return
[python_utils.git] / executors.py
index 46812c2b49203c2b23c021978e8e6fe334b80afa..47b4a89a88d693d535ed2e036c6288829505a005 100644 (file)
@@ -74,10 +74,10 @@ def make_cloud_pickle(fun, *args, **kwargs):
 class BaseExecutor(ABC):
     def __init__(self, *, title=''):
         self.title = title
-        self.task_count = 0
         self.histogram = hist.SimpleHistogram(
             hist.SimpleHistogram.n_evenly_spaced_buckets(int(0), int(500), 50)
         )
+        self.task_count = 0
 
     @abstractmethod
     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
@@ -87,9 +87,34 @@ class BaseExecutor(ABC):
     def shutdown(self, wait: bool = True) -> None:
         pass
 
+    def shutdown_if_idle(self) -> bool:
+        """Shutdown the executor and return True if the executor is idle
+        (i.e. there are no pending or active tasks).  Return False
+        otherwise.  Note: this should only be called by the launcher
+        process.
+
+        """
+        if self.task_count == 0:
+            self.shutdown()
+            return True
+        return False
+
     def adjust_task_count(self, delta: int) -> None:
+        """Change the task count.  Note: do not call this method from a
+        worker, it should only be called by the launcher process /
+        thread / machine.
+
+        """
         self.task_count += delta
-        logger.debug(f'Executor current task count is {self.task_count}')
+        logger.debug(f'Adjusted task count by {delta} to {self.task_count}')
+
+    def get_task_count(self) -> int:
+        """Change the task count.  Note: do not call this method from a
+        worker, it should only be called by the launcher process /
+        thread / machine.
+
+        """
+        return self.task_count
 
 
 class ThreadExecutor(BaseExecutor):
@@ -104,34 +129,38 @@ class ThreadExecutor(BaseExecutor):
         self._thread_pool_executor = fut.ThreadPoolExecutor(
             max_workers=workers, thread_name_prefix="thread_executor_helper"
         )
+        self.already_shutdown = False
 
+    # This is run on a different thread; do not adjust task count here.
     def run_local_bundle(self, fun, *args, **kwargs):
         logger.debug(f"Running local bundle at {fun.__name__}")
-        start = time.time()
         result = fun(*args, **kwargs)
-        end = time.time()
-        self.adjust_task_count(-1)
-        duration = end - start
-        logger.debug(f"{fun.__name__} finished; used {duration:.1f}s")
-        self.histogram.add_item(duration)
         return result
 
     @overrides
     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
+        if self.already_shutdown:
+            raise Exception('Submitted work after shutdown.')
         self.adjust_task_count(+1)
         newargs = []
         newargs.append(function)
         for arg in args:
             newargs.append(arg)
-        return self._thread_pool_executor.submit(
+        start = time.time()
+        result = self._thread_pool_executor.submit(
             self.run_local_bundle, *newargs, **kwargs
         )
+        result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
+        result.add_done_callback(lambda _: self.adjust_task_count(-1))
+        return result
 
     @overrides
     def shutdown(self, wait=True) -> None:
-        logger.debug(f'Shutting down threadpool executor {self.title}')
-        print(self.histogram)
-        self._thread_pool_executor.shutdown(wait)
+        if not self.already_shutdown:
+            logger.debug(f'Shutting down threadpool executor {self.title}')
+            print(self.histogram)
+            self._thread_pool_executor.shutdown(wait)
+            self.already_shutdown = True
 
 
 class ProcessExecutor(BaseExecutor):
@@ -146,30 +175,34 @@ class ProcessExecutor(BaseExecutor):
         self._process_executor = fut.ProcessPoolExecutor(
             max_workers=workers,
         )
+        self.already_shutdown = False
 
+    # This is run in another process; do not adjust task count here.
     def run_cloud_pickle(self, pickle):
         fun, args, kwargs = cloudpickle.loads(pickle)
         logger.debug(f"Running pickled bundle at {fun.__name__}")
         result = fun(*args, **kwargs)
-        self.adjust_task_count(-1)
         return result
 
     @overrides
     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
+        if self.already_shutdown:
+            raise Exception('Submitted work after shutdown.')
         start = time.time()
         self.adjust_task_count(+1)
         pickle = make_cloud_pickle(function, *args, **kwargs)
         result = self._process_executor.submit(self.run_cloud_pickle, pickle)
-        result.add_done_callback(
-            lambda _: self.histogram.add_item(time.time() - start)
-        )
+        result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
+        result.add_done_callback(lambda _: self.adjust_task_count(-1))
         return result
 
     @overrides
     def shutdown(self, wait=True) -> None:
-        logger.debug(f'Shutting down processpool executor {self.title}')
-        self._process_executor.shutdown(wait)
-        print(self.histogram)
+        if not self.already_shutdown:
+            logger.debug(f'Shutting down processpool executor {self.title}')
+            self._process_executor.shutdown(wait)
+            print(self.histogram)
+            self.already_shutdown = True
 
     def __getstate__(self):
         state = self.__dict__.copy()
@@ -258,9 +291,7 @@ class RemoteExecutorStatus:
         self.finished_bundle_timings_per_worker: Dict[
             RemoteWorkerRecord, List[float]
         ] = {}
-        self.in_flight_bundles_by_worker: Dict[
-            RemoteWorkerRecord, Set[str]
-        ] = {}
+        self.in_flight_bundles_by_worker: Dict[RemoteWorkerRecord, Set[str]] = {}
         self.bundle_details_by_uuid: Dict[str, BundleDetails] = {}
         self.finished_bundle_timings: List[float] = []
         self.last_periodic_dump: Optional[float] = None
@@ -270,9 +301,7 @@ class RemoteExecutorStatus:
         # as a memory fence for modifications to bundle.
         self.lock: threading.Lock = threading.Lock()
 
-    def record_acquire_worker(
-        self, worker: RemoteWorkerRecord, uuid: str
-    ) -> None:
+    def record_acquire_worker(self, worker: RemoteWorkerRecord, uuid: str) -> None:
         with self.lock:
             self.record_acquire_worker_already_locked(worker, uuid)
 
@@ -290,9 +319,7 @@ class RemoteExecutorStatus:
         with self.lock:
             self.record_bundle_details_already_locked(details)
 
-    def record_bundle_details_already_locked(
-        self, details: BundleDetails
-    ) -> None:
+    def record_bundle_details_already_locked(self, details: BundleDetails) -> None:
         assert self.lock.locked()
         self.bundle_details_by_uuid[details.uuid] = details
 
@@ -303,9 +330,7 @@ class RemoteExecutorStatus:
         was_cancelled: bool,
     ) -> None:
         with self.lock:
-            self.record_release_worker_already_locked(
-                worker, uuid, was_cancelled
-            )
+            self.record_release_worker_already_locked(worker, uuid, was_cancelled)
 
     def record_release_worker_already_locked(
         self,
@@ -377,11 +402,7 @@ class RemoteExecutorStatus:
                 ret += f'    ...{in_flight} bundles currently in flight:\n'
                 for bundle_uuid in self.in_flight_bundles_by_worker[worker]:
                     details = self.bundle_details_by_uuid.get(bundle_uuid, None)
-                    pid = (
-                        str(details.pid)
-                        if (details and details.pid != 0)
-                        else "TBD"
-                    )
+                    pid = str(details.pid) if (details and details.pid != 0) else "TBD"
                     if self.start_per_bundle[bundle_uuid] is not None:
                         sec = ts - self.start_per_bundle[bundle_uuid]
                         ret += f'       (pid={pid}): {details} for {sec:.1f}s so far '
@@ -412,10 +433,7 @@ class RemoteExecutorStatus:
         assert self.lock.locked()
         self.total_bundles_submitted = total_bundles_submitted
         ts = time.time()
-        if (
-            self.last_periodic_dump is None
-            or ts - self.last_periodic_dump > 5.0
-        ):
+        if self.last_periodic_dump is None or ts - self.last_periodic_dump > 5.0:
             print(self)
             self.last_periodic_dump = ts
 
@@ -429,9 +447,7 @@ class RemoteWorkerSelectionPolicy(ABC):
         pass
 
     @abstractmethod
-    def acquire_worker(
-        self, machine_to_avoid=None
-    ) -> Optional[RemoteWorkerRecord]:
+    def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
         pass
 
 
@@ -444,26 +460,32 @@ class WeightedRandomRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
         return False
 
     @overrides
-    def acquire_worker(
-        self, machine_to_avoid=None
-    ) -> Optional[RemoteWorkerRecord]:
+    def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
         grabbag = []
         for worker in self.workers:
-            for x in range(0, worker.count):
-                for y in range(0, worker.weight):
-                    grabbag.append(worker)
-
-        for _ in range(0, 5):
-            random.shuffle(grabbag)
-            worker = grabbag[0]
-            if worker.machine != machine_to_avoid or _ > 2:
+            if worker.machine != machine_to_avoid:
                 if worker.count > 0:
-                    worker.count -= 1
-                    logger.debug(f'Selected worker {worker}')
-                    return worker
-        msg = 'Unexpectedly could not find a worker, retrying...'
-        logger.warning(msg)
-        return None
+                    for _ in range(worker.count * worker.weight):
+                        grabbag.append(worker)
+
+        if len(grabbag) == 0:
+            logger.debug(
+                f'There are no available workers that avoid {machine_to_avoid}...'
+            )
+            for worker in self.workers:
+                if worker.count > 0:
+                    for _ in range(worker.count * worker.weight):
+                        grabbag.append(worker)
+
+        if len(grabbag) == 0:
+            logger.warning('There are no available workers?!')
+            return None
+
+        worker = random.sample(grabbag, 1)[0]
+        assert worker.count > 0
+        worker.count -= 1
+        logger.debug(f'Chose worker {worker}')
+        return worker
 
 
 class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
@@ -534,9 +556,10 @@ class RemoteExecutor(BaseExecutor):
             self.heartbeat_thread,
             self.heartbeat_stop_event,
         ) = self.run_periodic_heartbeat()
+        self.already_shutdown = False
 
     @background_thread
-    def run_periodic_heartbeat(self, stop_event) -> None:
+    def run_periodic_heartbeat(self, stop_event: threading.Event) -> None:
         while not stop_event.is_set():
             time.sleep(5.0)
             logger.debug('Running periodic heartbeat code...')
@@ -544,8 +567,10 @@ class RemoteExecutor(BaseExecutor):
         logger.debug('Periodic heartbeat thread shutting down.')
 
     def heartbeat(self) -> None:
+        # Note: this is invoked on a background thread, not an
+        # executor thread.  Be careful what you do with it b/c it
+        # needs to get back and dump status again periodically.
         with self.status.lock:
-            # Dump regular progress report
             self.status.periodic_dump(self.total_bundles_submitted)
 
             # Look for bundles to reschedule via executor.submit
@@ -560,7 +585,7 @@ class RemoteExecutor(BaseExecutor):
         if (
             num_done > 2
             and num_idle_workers > 1
-            and (self.last_backup is None or (now - self.last_backup > 6.0))
+            and (self.last_backup is None or (now - self.last_backup > 9.0))
             and self.backup_lock.acquire(blocking=False)
         ):
             try:
@@ -585,9 +610,7 @@ class RemoteExecutor(BaseExecutor):
                             break
 
                     for uuid in bundle_uuids:
-                        bundle = self.status.bundle_details_by_uuid.get(
-                            uuid, None
-                        )
+                        bundle = self.status.bundle_details_by_uuid.get(uuid, None)
                         if (
                             bundle is not None
                             and bundle.src_bundle is None
@@ -678,9 +701,7 @@ class RemoteExecutor(BaseExecutor):
         logger.critical(msg)
         raise Exception(msg)
 
-    def release_worker(
-        self, bundle: BundleDetails, *, was_cancelled=True
-    ) -> None:
+    def release_worker(self, bundle: BundleDetails, *, was_cancelled=True) -> None:
         worker = bundle.worker
         assert worker is not None
         logger.debug(f'Released worker {worker}')
@@ -764,14 +785,14 @@ class RemoteExecutor(BaseExecutor):
         # Send input code / data to worker machine if it's not local.
         if hostname not in machine:
             try:
-                cmd = f'{SCP} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
+                cmd = (
+                    f'{SCP} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
+                )
                 start_ts = time.time()
                 logger.info(f"{bundle}: Copying work to {worker} via {cmd}.")
                 run_silently(cmd)
                 xfer_latency = time.time() - start_ts
-                logger.debug(
-                    f"{bundle}: Copying to {worker} took {xfer_latency:.1f}s."
-                )
+                logger.debug(f"{bundle}: Copying to {worker} took {xfer_latency:.1f}s.")
             except Exception as e:
                 self.release_worker(bundle)
                 if is_original:
@@ -804,9 +825,7 @@ class RemoteExecutor(BaseExecutor):
             f' /home/scott/lib/python_modules/remote_worker.py'
             f' --code_file {bundle.code_file} --result_file {bundle.result_file}"'
         )
-        logger.debug(
-            f'{bundle}: Executing {cmd} in the background to kick off work...'
-        )
+        logger.debug(f'{bundle}: Executing {cmd} in the background to kick off work...')
         p = cmd_in_background(cmd, silent=True)
         bundle.pid = p.pid
         logger.debug(
@@ -935,9 +954,7 @@ class RemoteExecutor(BaseExecutor):
                 # Re-raise the exception; the code in wait_for_process may
                 # decide to emergency_retry_nasty_bundle here.
                 raise Exception(e)
-            logger.debug(
-                f'Removing local (master) {code_file} and {result_file}.'
-            )
+            logger.debug(f'Removing local (master) {code_file} and {result_file}.')
             os.remove(f'{result_file}')
             os.remove(f'{code_file}')
 
@@ -1082,6 +1099,8 @@ class RemoteExecutor(BaseExecutor):
 
     @overrides
     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
+        if self.already_shutdown:
+            raise Exception('Submitted work after shutdown.')
         pickle = make_cloud_pickle(function, *args, **kwargs)
         bundle = self.create_original_bundle(pickle, function.__name__)
         self.total_bundles_submitted += 1
@@ -1089,11 +1108,13 @@ class RemoteExecutor(BaseExecutor):
 
     @overrides
     def shutdown(self, wait=True) -> None:
-        logging.debug(f'Shutting down RemoteExecutor {self.title}')
-        self.heartbeat_stop_event.set()
-        self.heartbeat_thread.join()
-        self._helper_executor.shutdown(wait)
-        print(self.histogram)
+        if not self.already_shutdown:
+            logging.debug(f'Shutting down RemoteExecutor {self.title}')
+            self.heartbeat_stop_event.set()
+            self.heartbeat_thread.join()
+            self._helper_executor.shutdown(wait)
+            print(self.histogram)
+            self.already_shutdown = True
 
 
 @singleton
@@ -1133,7 +1154,7 @@ class DefaultExecutors(object):
                     RemoteWorkerRecord(
                         username='scott',
                         machine='cheetah.house',
-                        weight=34,
+                        weight=30,
                         count=6,
                     ),
                 )
@@ -1163,7 +1184,7 @@ class DefaultExecutors(object):
                     RemoteWorkerRecord(
                         username='scott',
                         machine='puma.cabin',
-                        weight=25,
+                        weight=30,
                         count=6,
                     ),
                 )
@@ -1173,7 +1194,7 @@ class DefaultExecutors(object):
                     RemoteWorkerRecord(
                         username='scott',
                         machine='backup.house',
-                        weight=7,
+                        weight=8,
                         count=2,
                     ),
                 )