Adds a shutdown_if_idle method to the executors and also eliminates
authorScott <[email protected]>
Mon, 31 Jan 2022 01:36:47 +0000 (17:36 -0800)
committerScott <[email protected]>
Mon, 31 Jan 2022 01:36:47 +0000 (17:36 -0800)
a multiple-shutdown problem encountered when using them.

executors.py

index 1df1877582362bdc9971d64c94df44f6d3e943cd..5b77a42dc3d29ca6f42673a369e23f0962343c62 100644 (file)
@@ -87,6 +87,18 @@ class BaseExecutor(ABC):
     def shutdown(self, wait: bool = True) -> None:
         pass
 
+    def shutdown_if_idle(self) -> bool:
+        """Shutdown the executor and return True if the executor is idle
+        (i.e. there are no pending or active tasks).  Return False
+        otherwise.  Note: this should only be called by the launcher
+        process.
+
+        """
+        if self.task_count == 0:
+            self.shutdown()
+            return True
+        return False
+
     def adjust_task_count(self, delta: int) -> None:
         """Change the task count.  Note: do not call this method from a
         worker, it should only be called by the launcher process /
@@ -94,8 +106,7 @@ class BaseExecutor(ABC):
 
         """
         self.task_count += delta
-        print(f'(adjusted task count by {delta} to {self.task_count})')
-        logger.debug(f'Executor current task count is {self.task_count}')
+        logger.debug(f'Adjusted task count by {delta} to {self.task_count}')
 
     def get_task_count(self) -> int:
         """Change the task count.  Note: do not call this method from a
@@ -118,6 +129,7 @@ class ThreadExecutor(BaseExecutor):
         self._thread_pool_executor = fut.ThreadPoolExecutor(
             max_workers=workers, thread_name_prefix="thread_executor_helper"
         )
+        self.already_shutdown = False
 
     # This is run on a different thread; do not adjust task count here.
     def run_local_bundle(self, fun, *args, **kwargs):
@@ -127,6 +139,8 @@ class ThreadExecutor(BaseExecutor):
 
     @overrides
     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
+        if self.already_shutdown:
+            raise Exception('Submitted work after shutdown.')
         self.adjust_task_count(+1)
         newargs = []
         newargs.append(function)
@@ -141,9 +155,11 @@ class ThreadExecutor(BaseExecutor):
 
     @overrides
     def shutdown(self, wait=True) -> None:
-        logger.debug(f'Shutting down threadpool executor {self.title}')
-        print(self.histogram)
-        self._thread_pool_executor.shutdown(wait)
+        if not self.already_shutdown:
+            logger.debug(f'Shutting down threadpool executor {self.title}')
+            print(self.histogram)
+            self._thread_pool_executor.shutdown(wait)
+            self.already_shutdown = True
 
 
 class ProcessExecutor(BaseExecutor):
@@ -158,6 +174,7 @@ class ProcessExecutor(BaseExecutor):
         self._process_executor = fut.ProcessPoolExecutor(
             max_workers=workers,
         )
+        self.already_shutdown = False
 
     # This is run in another process; do not adjust task count here.
     def run_cloud_pickle(self, pickle):
@@ -168,6 +185,8 @@ class ProcessExecutor(BaseExecutor):
 
     @overrides
     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
+        if self.already_shutdown:
+            raise Exception('Submitted work after shutdown.')
         start = time.time()
         self.adjust_task_count(+1)
         pickle = make_cloud_pickle(function, *args, **kwargs)
@@ -178,9 +197,11 @@ class ProcessExecutor(BaseExecutor):
 
     @overrides
     def shutdown(self, wait=True) -> None:
-        logger.debug(f'Shutting down processpool executor {self.title}')
-        self._process_executor.shutdown(wait)
-        print(self.histogram)
+        if not self.already_shutdown:
+            logger.debug(f'Shutting down processpool executor {self.title}')
+            self._process_executor.shutdown(wait)
+            print(self.histogram)
+            self.already_shutdown = True
 
     def __getstate__(self):
         state = self.__dict__.copy()
@@ -534,6 +555,7 @@ class RemoteExecutor(BaseExecutor):
             self.heartbeat_thread,
             self.heartbeat_stop_event,
         ) = self.run_periodic_heartbeat()
+        self.already_shutdown = False
 
     @background_thread
     def run_periodic_heartbeat(self, stop_event: threading.Event) -> None:
@@ -1076,6 +1098,8 @@ class RemoteExecutor(BaseExecutor):
 
     @overrides
     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
+        if self.already_shutdown:
+            raise Exception('Submitted work after shutdown.')
         pickle = make_cloud_pickle(function, *args, **kwargs)
         bundle = self.create_original_bundle(pickle, function.__name__)
         self.total_bundles_submitted += 1
@@ -1083,11 +1107,13 @@ class RemoteExecutor(BaseExecutor):
 
     @overrides
     def shutdown(self, wait=True) -> None:
-        logging.debug(f'Shutting down RemoteExecutor {self.title}')
-        self.heartbeat_stop_event.set()
-        self.heartbeat_thread.join()
-        self._helper_executor.shutdown(wait)
-        print(self.histogram)
+        if not self.already_shutdown:
+            logging.debug(f'Shutting down RemoteExecutor {self.title}')
+            self.heartbeat_stop_event.set()
+            self.heartbeat_thread.join()
+            self._helper_executor.shutdown(wait)
+            print(self.histogram)
+            self.already_shutdown = True
 
 
 @singleton