Do not adjust task count from a child process or thread; always
authorScott <[email protected]>
Sun, 30 Jan 2022 22:48:32 +0000 (14:48 -0800)
committerScott <[email protected]>
Sun, 30 Jan 2022 22:48:32 +0000 (14:48 -0800)
do it in the launcher thread/process.

executors.py

index 20aa9d2e325e5b5e54a77836f6117b91d99c36be..1df1877582362bdc9971d64c94df44f6d3e943cd 100644 (file)
@@ -74,10 +74,10 @@ def make_cloud_pickle(fun, *args, **kwargs):
 class BaseExecutor(ABC):
     def __init__(self, *, title=''):
         self.title = title
-        self.task_count = 0
         self.histogram = hist.SimpleHistogram(
             hist.SimpleHistogram.n_evenly_spaced_buckets(int(0), int(500), 50)
         )
+        self.task_count = 0
 
     @abstractmethod
     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
@@ -88,9 +88,23 @@ class BaseExecutor(ABC):
         pass
 
     def adjust_task_count(self, delta: int) -> None:
+        """Change the task count.  Note: do not call this method from a
+        worker, it should only be called by the launcher process /
+        thread / machine.
+
+        """
         self.task_count += delta
+        print(f'(adjusted task count by {delta} to {self.task_count})')
         logger.debug(f'Executor current task count is {self.task_count}')
 
+    def get_task_count(self) -> int:
+        """Change the task count.  Note: do not call this method from a
+        worker, it should only be called by the launcher process /
+        thread / machine.
+
+        """
+        return self.task_count
+
 
 class ThreadExecutor(BaseExecutor):
     def __init__(self, max_workers: Optional[int] = None):
@@ -105,15 +119,10 @@ class ThreadExecutor(BaseExecutor):
             max_workers=workers, thread_name_prefix="thread_executor_helper"
         )
 
+    # This is run on a different thread; do not adjust task count here.
     def run_local_bundle(self, fun, *args, **kwargs):
         logger.debug(f"Running local bundle at {fun.__name__}")
-        start = time.time()
         result = fun(*args, **kwargs)
-        end = time.time()
-        self.adjust_task_count(-1)
-        duration = end - start
-        logger.debug(f"{fun.__name__} finished; used {duration:.1f}s")
-        self.histogram.add_item(duration)
         return result
 
     @overrides
@@ -123,9 +132,12 @@ class ThreadExecutor(BaseExecutor):
         newargs.append(function)
         for arg in args:
             newargs.append(arg)
-        return self._thread_pool_executor.submit(
+        start = time.time()
+        result = self._thread_pool_executor.submit(
             self.run_local_bundle, *newargs, **kwargs
         )
+        result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
+        result.add_done_callback(lambda _: self.adjust_task_count(-1))
 
     @overrides
     def shutdown(self, wait=True) -> None:
@@ -147,11 +159,11 @@ class ProcessExecutor(BaseExecutor):
             max_workers=workers,
         )
 
+    # This is run in another process; do not adjust task count here.
     def run_cloud_pickle(self, pickle):
         fun, args, kwargs = cloudpickle.loads(pickle)
         logger.debug(f"Running pickled bundle at {fun.__name__}")
         result = fun(*args, **kwargs)
-        self.adjust_task_count(-1)
         return result
 
     @overrides
@@ -161,6 +173,7 @@ class ProcessExecutor(BaseExecutor):
         pickle = make_cloud_pickle(function, *args, **kwargs)
         result = self._process_executor.submit(self.run_cloud_pickle, pickle)
         result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
+        result.add_done_callback(lambda _: self.adjust_task_count(-1))
         return result
 
     @overrides