Used isort to sort imports. Also added to the git pre-commit hook.
[python_utils.git] / executors.py
index 633b11b3b616cc2536ea514d274cbf2874a99282..34528a33c2d10236cd7527fe53ffd027fa8020ca 100644 (file)
@@ -2,33 +2,32 @@
 
 from __future__ import annotations
 
 
 from __future__ import annotations
 
-from abc import ABC, abstractmethod
 import concurrent.futures as fut
 import concurrent.futures as fut
-from collections import defaultdict
-from dataclasses import dataclass
 import logging
 import logging
-import numpy
 import os
 import platform
 import random
 import subprocess
 import threading
 import time
 import os
 import platform
 import random
 import subprocess
 import threading
 import time
-from typing import Any, Callable, Dict, List, Optional, Set
 import warnings
 import warnings
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Set
 
 import cloudpickle  # type: ignore
 
 import cloudpickle  # type: ignore
+import numpy
 from overrides import overrides
 
 from overrides import overrides
 
-from ansi import bg, fg, underline, reset
 import argparse_utils
 import config
 import argparse_utils
 import config
-from decorator_utils import singleton
-from exec_utils import run_silently, cmd_in_background, cmd_with_timeout
 import histogram as hist
 import histogram as hist
+from ansi import bg, fg, reset, underline
+from decorator_utils import singleton
+from exec_utils import cmd_in_background, cmd_with_timeout, run_silently
 from thread_utils import background_thread
 
 from thread_utils import background_thread
 
-
 logger = logging.getLogger(__name__)
 
 parser = config.add_commandline_args(
 logger = logging.getLogger(__name__)
 
 parser = config.add_commandline_args(
@@ -84,10 +83,10 @@ class BaseExecutor(ABC):
         pass
 
     @abstractmethod
         pass
 
     @abstractmethod
-    def shutdown(self, wait: bool = True) -> None:
+    def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
         pass
 
         pass
 
-    def shutdown_if_idle(self) -> bool:
+    def shutdown_if_idle(self, *, quiet: bool = False) -> bool:
         """Shutdown the executor and return True if the executor is idle
         (i.e. there are no pending or active tasks).  Return False
         otherwise.  Note: this should only be called by the launcher
         """Shutdown the executor and return True if the executor is idle
         (i.e. there are no pending or active tasks).  Return False
         otherwise.  Note: this should only be called by the launcher
@@ -95,7 +94,7 @@ class BaseExecutor(ABC):
 
         """
         if self.task_count == 0:
 
         """
         if self.task_count == 0:
-            self.shutdown()
+            self.shutdown(wait=True, quiet=quiet)
             return True
         return False
 
             return True
         return False
 
@@ -155,11 +154,12 @@ class ThreadExecutor(BaseExecutor):
         return result
 
     @overrides
         return result
 
     @overrides
-    def shutdown(self, wait=True) -> None:
+    def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
         if not self.already_shutdown:
             logger.debug(f'Shutting down threadpool executor {self.title}')
         if not self.already_shutdown:
             logger.debug(f'Shutting down threadpool executor {self.title}')
-            print(self.histogram.__repr__(label_formatter='%ds'))
             self._thread_pool_executor.shutdown(wait)
             self._thread_pool_executor.shutdown(wait)
+            if not quiet:
+                print(self.histogram.__repr__(label_formatter='%ds'))
             self.already_shutdown = True
 
 
             self.already_shutdown = True
 
 
@@ -197,11 +197,12 @@ class ProcessExecutor(BaseExecutor):
         return result
 
     @overrides
         return result
 
     @overrides
-    def shutdown(self, wait=True) -> None:
+    def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
         if not self.already_shutdown:
             logger.debug(f'Shutting down processpool executor {self.title}')
             self._process_executor.shutdown(wait)
         if not self.already_shutdown:
             logger.debug(f'Shutting down processpool executor {self.title}')
             self._process_executor.shutdown(wait)
-            print(self.histogram.__repr__(label_formatter='%ds'))
+            if not quiet:
+                print(self.histogram.__repr__(label_formatter='%ds'))
             self.already_shutdown = True
 
     def __getstate__(self):
             self.already_shutdown = True
 
     def __getstate__(self):
@@ -246,7 +247,7 @@ class BundleDetails:
     end_ts: float
     slower_than_local_p95: bool
     slower_than_global_p95: bool
     end_ts: float
     slower_than_local_p95: bool
     slower_than_global_p95: bool
-    src_bundle: BundleDetails
+    src_bundle: Optional[BundleDetails]
     is_cancelled: threading.Event
     was_cancelled: bool
     backup_bundles: Optional[List[BundleDetails]]
     is_cancelled: threading.Event
     was_cancelled: bool
     backup_bundles: Optional[List[BundleDetails]]
@@ -286,7 +287,7 @@ class RemoteExecutorStatus:
         self.worker_count: int = total_worker_count
         self.known_workers: Set[RemoteWorkerRecord] = set()
         self.start_time: float = time.time()
         self.worker_count: int = total_worker_count
         self.known_workers: Set[RemoteWorkerRecord] = set()
         self.start_time: float = time.time()
-        self.start_per_bundle: Dict[str, float] = defaultdict(float)
+        self.start_per_bundle: Dict[str, Optional[float]] = defaultdict(float)
         self.end_per_bundle: Dict[str, float] = defaultdict(float)
         self.finished_bundle_timings_per_worker: Dict[
             RemoteWorkerRecord, List[float]
         self.end_per_bundle: Dict[str, float] = defaultdict(float)
         self.finished_bundle_timings_per_worker: Dict[
             RemoteWorkerRecord, List[float]
@@ -343,7 +344,9 @@ class RemoteExecutorStatus:
         self.end_per_bundle[uuid] = ts
         self.in_flight_bundles_by_worker[worker].remove(uuid)
         if not was_cancelled:
         self.end_per_bundle[uuid] = ts
         self.in_flight_bundles_by_worker[worker].remove(uuid)
         if not was_cancelled:
-            bundle_latency = ts - self.start_per_bundle[uuid]
+            start = self.start_per_bundle[uuid]
+            assert start
+            bundle_latency = ts - start
             x = self.finished_bundle_timings_per_worker.get(worker, list())
             x.append(bundle_latency)
             self.finished_bundle_timings_per_worker[worker] = x
             x = self.finished_bundle_timings_per_worker.get(worker, list())
             x.append(bundle_latency)
             self.finished_bundle_timings_per_worker[worker] = x
@@ -834,9 +837,10 @@ class RemoteExecutor(BaseExecutor):
         return self.wait_for_process(p, bundle, 0)
 
     def wait_for_process(
         return self.wait_for_process(p, bundle, 0)
 
     def wait_for_process(
-        self, p: subprocess.Popen, bundle: BundleDetails, depth: int
+        self, p: Optional[subprocess.Popen], bundle: BundleDetails, depth: int
     ) -> Any:
         machine = bundle.machine
     ) -> Any:
         machine = bundle.machine
+        assert p
         pid = p.pid
         if depth > 3:
             logger.error(
         pid = p.pid
         if depth > 3:
             logger.error(
@@ -979,10 +983,12 @@ class RemoteExecutor(BaseExecutor):
 
             # Tell the original to stop if we finished first.
             if not was_cancelled:
 
             # Tell the original to stop if we finished first.
             if not was_cancelled:
+                orig_bundle = bundle.src_bundle
+                assert orig_bundle
                 logger.debug(
                 logger.debug(
-                    f'{bundle}: Notifying original {bundle.src_bundle.uuid} we beat them to it.'
+                    f'{bundle}: Notifying original {orig_bundle.uuid} we beat them to it.'
                 )
                 )
-                bundle.src_bundle.is_cancelled.set()
+                orig_bundle.is_cancelled.set()
         self.release_worker(bundle, was_cancelled=was_cancelled)
         return result
 
         self.release_worker(bundle, was_cancelled=was_cancelled)
         return result
 
@@ -1066,7 +1072,9 @@ class RemoteExecutor(BaseExecutor):
         # they will move the result_file to this machine and let
         # the original pick them up and unpickle them.
 
         # they will move the result_file to this machine and let
         # the original pick them up and unpickle them.
 
-    def emergency_retry_nasty_bundle(self, bundle: BundleDetails) -> fut.Future:
+    def emergency_retry_nasty_bundle(
+        self, bundle: BundleDetails
+    ) -> Optional[fut.Future]:
         is_original = bundle.src_bundle is None
         bundle.worker = None
         avoid_last_machine = bundle.machine
         is_original = bundle.src_bundle is None
         bundle.worker = None
         avoid_last_machine = bundle.machine
@@ -1107,13 +1115,14 @@ class RemoteExecutor(BaseExecutor):
         return self._helper_executor.submit(self.launch, bundle)
 
     @overrides
         return self._helper_executor.submit(self.launch, bundle)
 
     @overrides
-    def shutdown(self, wait=True) -> None:
+    def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
         if not self.already_shutdown:
             logging.debug(f'Shutting down RemoteExecutor {self.title}')
             self.heartbeat_stop_event.set()
             self.heartbeat_thread.join()
             self._helper_executor.shutdown(wait)
         if not self.already_shutdown:
             logging.debug(f'Shutting down RemoteExecutor {self.title}')
             self.heartbeat_stop_event.set()
             self.heartbeat_thread.join()
             self._helper_executor.shutdown(wait)
-            print(self.histogram.__repr__(label_formatter='%ds'))
+            if not quiet:
+                print(self.histogram.__repr__(label_formatter='%ds'))
             self.already_shutdown = True
 
 
             self.already_shutdown = True
 
 
@@ -1212,11 +1221,11 @@ class DefaultExecutors(object):
 
     def shutdown(self) -> None:
         if self.thread_executor is not None:
 
     def shutdown(self) -> None:
         if self.thread_executor is not None:
-            self.thread_executor.shutdown()
+            self.thread_executor.shutdown(wait=True, quiet=True)
             self.thread_executor = None
         if self.process_executor is not None:
             self.thread_executor = None
         if self.process_executor is not None:
-            self.process_executor.shutdown()
+            self.process_executor.shutdown(wait=True, quiet=True)
             self.process_executor = None
         if self.remote_executor is not None:
             self.process_executor = None
         if self.remote_executor is not None:
-            self.remote_executor.shutdown()
+            self.remote_executor.shutdown(wait=True, quiet=True)
             self.remote_executor = None
             self.remote_executor = None