More mypy cleanup... ugh.
[python_utils.git] / executors.py
index 1df1877582362bdc9971d64c94df44f6d3e943cd..e95ed716043b4962cd939b6d25885fd87826466a 100644 (file)
@@ -84,9 +84,21 @@ class BaseExecutor(ABC):
         pass
 
     @abstractmethod
-    def shutdown(self, wait: bool = True) -> None:
+    def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
         pass
 
+    def shutdown_if_idle(self, *, quiet: bool = False) -> bool:
+        """Shutdown the executor and return True if the executor is idle
+        (i.e. there are no pending or active tasks).  Return False
+        otherwise.  Note: this should only be called by the launcher
+        process.
+
+        """
+        if self.task_count == 0:
+            self.shutdown(wait=True, quiet=quiet)
+            return True
+        return False
+
     def adjust_task_count(self, delta: int) -> None:
         """Change the task count.  Note: do not call this method from a
         worker, it should only be called by the launcher process /
@@ -94,8 +106,7 @@ class BaseExecutor(ABC):
 
         """
         self.task_count += delta
-        print(f'(adjusted task count by {delta} to {self.task_count})')
-        logger.debug(f'Executor current task count is {self.task_count}')
+        logger.debug(f'Adjusted task count by {delta} to {self.task_count}')
 
     def get_task_count(self) -> int:
         """Change the task count.  Note: do not call this method from a
@@ -118,6 +129,7 @@ class ThreadExecutor(BaseExecutor):
         self._thread_pool_executor = fut.ThreadPoolExecutor(
             max_workers=workers, thread_name_prefix="thread_executor_helper"
         )
+        self.already_shutdown = False
 
     # This is run on a different thread; do not adjust task count here.
     def run_local_bundle(self, fun, *args, **kwargs):
@@ -127,6 +139,8 @@ class ThreadExecutor(BaseExecutor):
 
     @overrides
     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
+        if self.already_shutdown:
+            raise Exception('Submitted work after shutdown.')
         self.adjust_task_count(+1)
         newargs = []
         newargs.append(function)
@@ -138,12 +152,16 @@ class ThreadExecutor(BaseExecutor):
         )
         result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
         result.add_done_callback(lambda _: self.adjust_task_count(-1))
+        return result
 
     @overrides
-    def shutdown(self, wait=True) -> None:
-        logger.debug(f'Shutting down threadpool executor {self.title}')
-        print(self.histogram)
-        self._thread_pool_executor.shutdown(wait)
+    def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
+        if not self.already_shutdown:
+            logger.debug(f'Shutting down threadpool executor {self.title}')
+            self._thread_pool_executor.shutdown(wait)
+            if not quiet:
+                print(self.histogram.__repr__(label_formatter='%ds'))
+            self.already_shutdown = True
 
 
 class ProcessExecutor(BaseExecutor):
@@ -158,6 +176,7 @@ class ProcessExecutor(BaseExecutor):
         self._process_executor = fut.ProcessPoolExecutor(
             max_workers=workers,
         )
+        self.already_shutdown = False
 
     # This is run in another process; do not adjust task count here.
     def run_cloud_pickle(self, pickle):
@@ -168,6 +187,8 @@ class ProcessExecutor(BaseExecutor):
 
     @overrides
     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
+        if self.already_shutdown:
+            raise Exception('Submitted work after shutdown.')
         start = time.time()
         self.adjust_task_count(+1)
         pickle = make_cloud_pickle(function, *args, **kwargs)
@@ -177,10 +198,13 @@ class ProcessExecutor(BaseExecutor):
         return result
 
     @overrides
-    def shutdown(self, wait=True) -> None:
-        logger.debug(f'Shutting down processpool executor {self.title}')
-        self._process_executor.shutdown(wait)
-        print(self.histogram)
+    def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
+        if not self.already_shutdown:
+            logger.debug(f'Shutting down processpool executor {self.title}')
+            self._process_executor.shutdown(wait)
+            if not quiet:
+                print(self.histogram.__repr__(label_formatter='%ds'))
+            self.already_shutdown = True
 
     def __getstate__(self):
         state = self.__dict__.copy()
@@ -534,6 +558,7 @@ class RemoteExecutor(BaseExecutor):
             self.heartbeat_thread,
             self.heartbeat_stop_event,
         ) = self.run_periodic_heartbeat()
+        self.already_shutdown = False
 
     @background_thread
     def run_periodic_heartbeat(self, stop_event: threading.Event) -> None:
@@ -1076,18 +1101,23 @@ class RemoteExecutor(BaseExecutor):
 
     @overrides
     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
+        if self.already_shutdown:
+            raise Exception('Submitted work after shutdown.')
         pickle = make_cloud_pickle(function, *args, **kwargs)
         bundle = self.create_original_bundle(pickle, function.__name__)
         self.total_bundles_submitted += 1
         return self._helper_executor.submit(self.launch, bundle)
 
     @overrides
-    def shutdown(self, wait=True) -> None:
-        logging.debug(f'Shutting down RemoteExecutor {self.title}')
-        self.heartbeat_stop_event.set()
-        self.heartbeat_thread.join()
-        self._helper_executor.shutdown(wait)
-        print(self.histogram)
+    def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
+        if not self.already_shutdown:
+            logging.debug(f'Shutting down RemoteExecutor {self.title}')
+            self.heartbeat_stop_event.set()
+            self.heartbeat_thread.join()
+            self._helper_executor.shutdown(wait)
+            if not quiet:
+                print(self.histogram.__repr__(label_formatter='%ds'))
+            self.already_shutdown = True
 
 
 @singleton
@@ -1185,11 +1215,11 @@ class DefaultExecutors(object):
 
     def shutdown(self) -> None:
         if self.thread_executor is not None:
-            self.thread_executor.shutdown()
+            self.thread_executor.shutdown(wait=True, quiet=True)
             self.thread_executor = None
         if self.process_executor is not None:
-            self.process_executor.shutdown()
+            self.process_executor.shutdown(wait=True, quiet=True)
             self.process_executor = None
         if self.remote_executor is not None:
-            self.remote_executor.shutdown()
+            self.remote_executor.shutdown(wait=True, quiet=True)
             self.remote_executor = None