Ignore integration test results in code coverage report.
[python_utils.git] / executors.py
index 3786954a418c257bd2c7d0ae81d1455c0d8ea1cc..69330129c9b86ec8c9710b2586b24d7c2dfee8f6 100644 (file)
@@ -1,34 +1,34 @@
 #!/usr/bin/env python3
+# -*- coding: utf-8 -*-
 
 from __future__ import annotations
 
-from abc import ABC, abstractmethod
 import concurrent.futures as fut
-from collections import defaultdict
-from dataclasses import dataclass
 import logging
-import numpy
 import os
 import platform
 import random
 import subprocess
 import threading
 import time
-from typing import Any, Callable, Dict, List, Optional, Set
 import warnings
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Set
 
 import cloudpickle  # type: ignore
+import numpy
 from overrides import overrides
 
-from ansi import bg, fg, underline, reset
 import argparse_utils
 import config
-from decorator_utils import singleton
-from exec_utils import run_silently, cmd_in_background, cmd_with_timeout
 import histogram as hist
+from ansi import bg, fg, reset, underline
+from decorator_utils import singleton
+from exec_utils import cmd_in_background, cmd_with_timeout, run_silently
 from thread_utils import background_thread
 
-
 logger = logging.getLogger(__name__)
 
 parser = config.add_commandline_args(
@@ -74,22 +74,47 @@ def make_cloud_pickle(fun, *args, **kwargs):
 class BaseExecutor(ABC):
     def __init__(self, *, title=''):
         self.title = title
-        self.task_count = 0
         self.histogram = hist.SimpleHistogram(
             hist.SimpleHistogram.n_evenly_spaced_buckets(int(0), int(500), 50)
         )
+        self.task_count = 0
 
     @abstractmethod
     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
         pass
 
     @abstractmethod
-    def shutdown(self, wait: bool = True) -> None:
+    def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
         pass
 
+    def shutdown_if_idle(self, *, quiet: bool = False) -> bool:
+        """Shutdown the executor and return True if the executor is idle
+        (i.e. there are no pending or active tasks).  Return False
+        otherwise.  Note: this should only be called by the launcher
+        process.
+
+        """
+        if self.task_count == 0:
+            self.shutdown(wait=True, quiet=quiet)
+            return True
+        return False
+
     def adjust_task_count(self, delta: int) -> None:
+        """Change the task count.  Note: do not call this method from a
+        worker; it should only be called by the launcher process /
+        thread / machine.
+
+        """
         self.task_count += delta
-        logger.debug(f'Executor current task count is {self.task_count}')
+        logger.debug(f'Adjusted task count by {delta} to {self.task_count}')
+
+    def get_task_count(self) -> int:
+        """Change the task count.  Note: do not call this method from a
+        worker, it should only be called by the launcher process /
+        thread / machine.
+
+        """
+        return self.task_count
 
 
 class ThreadExecutor(BaseExecutor):
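
The new launcher-side bookkeeping above (adjust_task_count / get_task_count /
shutdown_if_idle) can be driven roughly as follows. This is a minimal sketch,
not part of the patch; `executor` is assumed to be any concrete BaseExecutor
(e.g. a ThreadExecutor) constructed elsewhere.

    import time

    def drain_and_shutdown(executor, poll_interval: float = 0.5) -> None:
        # Poll the launcher-side task count; shutdown_if_idle() only shuts
        # the executor down (and returns True) once no tasks are outstanding.
        while not executor.shutdown_if_idle(quiet=True):
            time.sleep(poll_interval)
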
@@ -104,34 +129,37 @@ class ThreadExecutor(BaseExecutor):
         self._thread_pool_executor = fut.ThreadPoolExecutor(
             max_workers=workers, thread_name_prefix="thread_executor_helper"
         )
+        self.already_shutdown = False
 
+    # This is run on a different thread; do not adjust task count here.
     def run_local_bundle(self, fun, *args, **kwargs):
         logger.debug(f"Running local bundle at {fun.__name__}")
-        start = time.time()
         result = fun(*args, **kwargs)
-        end = time.time()
-        self.adjust_task_count(-1)
-        duration = end - start
-        logger.debug(f"{fun.__name__} finished; used {duration:.1f}s")
-        self.histogram.add_item(duration)
         return result
 
     @overrides
     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
+        if self.already_shutdown:
+            raise Exception('Submitted work after shutdown.')
         self.adjust_task_count(+1)
         newargs = []
         newargs.append(function)
         for arg in args:
             newargs.append(arg)
-        return self._thread_pool_executor.submit(
-            self.run_local_bundle, *newargs, **kwargs
-        )
+        start = time.time()
+        result = self._thread_pool_executor.submit(self.run_local_bundle, *newargs, **kwargs)
+        result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
+        result.add_done_callback(lambda _: self.adjust_task_count(-1))
+        return result
 
     @overrides
-    def shutdown(self, wait=True) -> None:
-        logger.debug(f'Shutting down threadpool executor {self.title}')
-        print(self.histogram)
-        self._thread_pool_executor.shutdown(wait)
+    def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
+        if not self.already_shutdown:
+            logger.debug(f'Shutting down threadpool executor {self.title}')
+            self._thread_pool_executor.shutdown(wait)
+            if not quiet:
+                print(self.histogram.__repr__(label_formatter='%ds'))
+            self.already_shutdown = True
 
 
 class ProcessExecutor(BaseExecutor):
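
The submit() change above moves timing and task-count bookkeeping into Future
done-callbacks, so run_local_bundle no longer touches launcher state from a
worker thread. A standalone sketch of that callback pattern using only the
standard library (the names here are illustrative, not from the patch):

    import concurrent.futures as fut
    import time

    pool = fut.ThreadPoolExecutor(max_workers=4)
    durations = []

    def submit_timed(fn, *args, **kwargs) -> fut.Future:
        start = time.time()
        f = pool.submit(fn, *args, **kwargs)
        # The callback fires when the future completes; keep it cheap and
        # thread-safe (list.append is atomic under the GIL).
        f.add_done_callback(lambda _: durations.append(time.time() - start))
        return f

    if __name__ == '__main__':
        fut.wait([submit_timed(time.sleep, 0.1) for _ in range(8)])
        print(f'{len(durations)} tasks, slowest took {max(durations):.2f}s')
        pool.shutdown(wait=True)
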
@@ -146,28 +174,35 @@ class ProcessExecutor(BaseExecutor):
         self._process_executor = fut.ProcessPoolExecutor(
             max_workers=workers,
         )
+        self.already_shutdown = False
 
+    # This is run in another process; do not adjust task count here.
     def run_cloud_pickle(self, pickle):
         fun, args, kwargs = cloudpickle.loads(pickle)
         logger.debug(f"Running pickled bundle at {fun.__name__}")
         result = fun(*args, **kwargs)
-        self.adjust_task_count(-1)
         return result
 
     @overrides
     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
+        if self.already_shutdown:
+            raise Exception('Submitted work after shutdown.')
         start = time.time()
         self.adjust_task_count(+1)
         pickle = make_cloud_pickle(function, *args, **kwargs)
         result = self._process_executor.submit(self.run_cloud_pickle, pickle)
         result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
+        result.add_done_callback(lambda _: self.adjust_task_count(-1))
         return result
 
     @overrides
-    def shutdown(self, wait=True) -> None:
-        logger.debug(f'Shutting down processpool executor {self.title}')
-        self._process_executor.shutdown(wait)
-        print(self.histogram)
+    def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
+        if not self.already_shutdown:
+            logger.debug(f'Shutting down processpool executor {self.title}')
+            self._process_executor.shutdown(wait)
+            if not quiet:
+                print(self.histogram.__repr__(label_formatter='%ds'))
+            self.already_shutdown = True
 
     def __getstate__(self):
         state = self.__dict__.copy()
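
For context, the ProcessExecutor path works because make_cloud_pickle()
serializes the callable together with its arguments and run_cloud_pickle()
reverses that inside the worker process. A rough, self-contained illustration
of the round trip (assuming cloudpickle is installed):

    import cloudpickle  # type: ignore

    def make_blob(fun, *args, **kwargs) -> bytes:
        # Bundle callable + args into one bytes blob; cloudpickle handles
        # lambdas and closures that the stdlib pickle module rejects.
        return cloudpickle.dumps((fun, args, kwargs))

    def run_blob(blob: bytes):
        fun, args, kwargs = cloudpickle.loads(blob)
        return fun(*args, **kwargs)

    if __name__ == '__main__':
        offset = 10
        blob = make_blob(lambda x: x + offset, 32)
        print(run_blob(blob))  # prints 42
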
@@ -211,7 +246,7 @@ class BundleDetails:
     end_ts: float
     slower_than_local_p95: bool
     slower_than_global_p95: bool
-    src_bundle: BundleDetails
+    src_bundle: Optional[BundleDetails]
     is_cancelled: threading.Event
     was_cancelled: bool
     backup_bundles: Optional[List[BundleDetails]]
@@ -251,11 +286,9 @@ class RemoteExecutorStatus:
         self.worker_count: int = total_worker_count
         self.known_workers: Set[RemoteWorkerRecord] = set()
         self.start_time: float = time.time()
-        self.start_per_bundle: Dict[str, float] = defaultdict(float)
+        self.start_per_bundle: Dict[str, Optional[float]] = defaultdict(float)
         self.end_per_bundle: Dict[str, float] = defaultdict(float)
-        self.finished_bundle_timings_per_worker: Dict[
-            RemoteWorkerRecord, List[float]
-        ] = {}
+        self.finished_bundle_timings_per_worker: Dict[RemoteWorkerRecord, List[float]] = {}
         self.in_flight_bundles_by_worker: Dict[RemoteWorkerRecord, Set[str]] = {}
         self.bundle_details_by_uuid: Dict[str, BundleDetails] = {}
         self.finished_bundle_timings: List[float] = []
@@ -270,9 +303,7 @@ class RemoteExecutorStatus:
         with self.lock:
             self.record_acquire_worker_already_locked(worker, uuid)
 
-    def record_acquire_worker_already_locked(
-        self, worker: RemoteWorkerRecord, uuid: str
-    ) -> None:
+    def record_acquire_worker_already_locked(self, worker: RemoteWorkerRecord, uuid: str) -> None:
         assert self.lock.locked()
         self.known_workers.add(worker)
         self.start_per_bundle[uuid] = None
@@ -308,7 +339,9 @@ class RemoteExecutorStatus:
         self.end_per_bundle[uuid] = ts
         self.in_flight_bundles_by_worker[worker].remove(uuid)
         if not was_cancelled:
-            bundle_latency = ts - self.start_per_bundle[uuid]
+            start = self.start_per_bundle[uuid]
+            assert start is not None
+            bundle_latency = ts - start
             x = self.finished_bundle_timings_per_worker.get(worker, list())
             x.append(bundle_latency)
             self.finished_bundle_timings_per_worker[worker] = x
@@ -428,21 +461,27 @@ class WeightedRandomRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
     def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
         grabbag = []
         for worker in self.workers:
-            for x in range(0, worker.count):
-                for y in range(0, worker.weight):
-                    grabbag.append(worker)
-
-        for _ in range(0, 5):
-            random.shuffle(grabbag)
-            worker = grabbag[0]
-            if worker.machine != machine_to_avoid or _ > 2:
+            if worker.machine != machine_to_avoid:
+                if worker.count > 0:
+                    for _ in range(worker.count * worker.weight):
+                        grabbag.append(worker)
+
+        if len(grabbag) == 0:
+            logger.debug(f'There are no available workers that avoid {machine_to_avoid}...')
+            for worker in self.workers:
                 if worker.count > 0:
-                    worker.count -= 1
-                    logger.debug(f'Selected worker {worker}')
-                    return worker
-        msg = 'Unexpectedly could not find a worker, retrying...'
-        logger.warning(msg)
-        return None
+                    for _ in range(worker.count * worker.weight):
+                        grabbag.append(worker)
+
+        if len(grabbag) == 0:
+            logger.warning('There are no available workers?!')
+            return None
+
+        worker = random.sample(grabbag, 1)[0]
+        assert worker.count > 0
+        worker.count -= 1
+        logger.debug(f'Chose worker {worker}')
+        return worker
 
 
 class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
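
The rewritten acquire_worker() above replaces the shuffle-and-retry loop with
a flat "grabbag": each worker with free slots appears count * weight times and
one entry is drawn uniformly, which amounts to weighted random selection with
a fallback to the avoided machine when nothing else is available. A toy sketch
of the same idea (the Candidate class and the numbers are made up for
illustration):

    import random
    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class Candidate:
        name: str
        count: int   # free slots
        weight: int  # relative preference

    def pick(candidates: List[Candidate], avoid: Optional[str] = None) -> Optional[Candidate]:
        def build(skip_avoided: bool) -> List[Candidate]:
            return [
                c
                for c in candidates
                if c.count > 0 and (not skip_avoided or c.name != avoid)
                for _ in range(c.count * c.weight)
            ]

        # Try to honor the avoided machine first; fall back to everyone.
        grabbag = build(skip_avoided=True) or build(skip_avoided=False)
        if not grabbag:
            return None
        chosen = random.choice(grabbag)
        chosen.count -= 1
        return chosen

    workers = [Candidate('cheetah', 6, 30), Candidate('backup', 2, 8)]
    print(pick(workers, avoid='cheetah').name)  # 'backup', since cheetah is avoided
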
@@ -457,9 +496,7 @@ class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
         return False
 
     @overrides
-    def acquire_worker(
-        self, machine_to_avoid: str = None
-    ) -> Optional[RemoteWorkerRecord]:
+    def acquire_worker(self, machine_to_avoid: str = None) -> Optional[RemoteWorkerRecord]:
         x = self.index
         while True:
             worker = self.workers[x]
@@ -482,7 +519,9 @@ class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
 
 class RemoteExecutor(BaseExecutor):
     def __init__(
-        self, workers: List[RemoteWorkerRecord], policy: RemoteWorkerSelectionPolicy
+        self,
+        workers: List[RemoteWorkerRecord],
+        policy: RemoteWorkerSelectionPolicy,
     ) -> None:
         super().__init__()
         self.workers = workers
@@ -496,9 +535,7 @@ class RemoteExecutor(BaseExecutor):
             raise RemoteExecutorException(msg)
         self.policy.register_worker_pool(self.workers)
         self.cv = threading.Condition()
-        logger.debug(
-            f'Creating {self.worker_count} local threads, one per remote worker.'
-        )
+        logger.debug(f'Creating {self.worker_count} local threads, one per remote worker.')
         self._helper_executor = fut.ThreadPoolExecutor(
             thread_name_prefix="remote_executor_helper",
             max_workers=self.worker_count,
@@ -511,9 +548,10 @@ class RemoteExecutor(BaseExecutor):
             self.heartbeat_thread,
             self.heartbeat_stop_event,
         ) = self.run_periodic_heartbeat()
+        self.already_shutdown = False
 
     @background_thread
-    def run_periodic_heartbeat(self, stop_event) -> None:
+    def run_periodic_heartbeat(self, stop_event: threading.Event) -> None:
         while not stop_event.is_set():
             time.sleep(5.0)
             logger.debug('Running periodic heartbeat code...')
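
run_periodic_heartbeat() leans on the @background_thread decorator from
thread_utils, which, as used here, hands the loop a stop event and returns a
(thread, event) pair. A plain-threading approximation of that contract, purely
for illustration (the real decorator may differ):

    import threading
    import time

    def start_heartbeat(period: float = 5.0):
        stop_event = threading.Event()

        def _loop() -> None:
            while not stop_event.is_set():
                time.sleep(period)
                print('heartbeat')  # stand-in for the real periodic work
            print('heartbeat thread exiting')

        thread = threading.Thread(target=_loop, daemon=True)
        thread.start()
        return thread, stop_event

    # Shutdown mirrors RemoteExecutor.shutdown(): set the event, then join.
    #   thread, stop = start_heartbeat(); ...; stop.set(); thread.join()
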
@@ -521,8 +559,10 @@ class RemoteExecutor(BaseExecutor):
         logger.debug('Periodic heartbeat thread shutting down.')
 
     def heartbeat(self) -> None:
+        # Note: this is invoked on a background thread, not an
+        # executor thread.  Be careful what you do with it b/c it
+        # needs to get back and dump status again periodically.
         with self.status.lock:
-            # Dump regular progress report
             self.status.periodic_dump(self.total_bundles_submitted)
 
             # Look for bundles to reschedule via executor.submit
@@ -537,7 +577,7 @@ class RemoteExecutor(BaseExecutor):
         if (
             num_done > 2
             and num_idle_workers > 1
-            and (self.last_backup is None or (now - self.last_backup > 6.0))
+            and (self.last_backup is None or (now - self.last_backup > 9.0))
             and self.backup_lock.acquire(blocking=False)
         ):
             try:
@@ -577,21 +617,15 @@ class RemoteExecutor(BaseExecutor):
                             if start_ts is not None:
                                 runtime = now - start_ts
                                 score += runtime
-                                logger.debug(
-                                    f'score[{bundle}] => {score}  # latency boost'
-                                )
+                                logger.debug(f'score[{bundle}] => {score}  # latency boost')
 
                                 if bundle.slower_than_local_p95:
                                     score += runtime / 2
-                                    logger.debug(
-                                        f'score[{bundle}] => {score}  # >worker p95'
-                                    )
+                                    logger.debug(f'score[{bundle}] => {score}  # >worker p95')
 
                                 if bundle.slower_than_global_p95:
                                     score += runtime / 4
-                                    logger.debug(
-                                        f'score[{bundle}] => {score}  # >global p95'
-                                    )
+                                    logger.debug(f'score[{bundle}] => {score}  # >global p95')
 
                             # Prefer backups of bundles that don't
                             # have backups already.
@@ -608,9 +642,7 @@ class RemoteExecutor(BaseExecutor):
                                 f'score[{bundle}] => {score}  # {backup_count} dup backup factor'
                             )
 
-                            if score != 0 and (
-                                best_score is None or score > best_score
-                            ):
+                            if score != 0 and (best_score is None or score > best_score):
                                 bundle_to_backup = bundle
                                 assert bundle is not None
                                 assert bundle.backup_bundles is not None
@@ -635,14 +667,10 @@ class RemoteExecutor(BaseExecutor):
     def is_worker_available(self) -> bool:
         return self.policy.is_worker_available()
 
-    def acquire_worker(
-        self, machine_to_avoid: str = None
-    ) -> Optional[RemoteWorkerRecord]:
+    def acquire_worker(self, machine_to_avoid: str = None) -> Optional[RemoteWorkerRecord]:
         return self.policy.acquire_worker(machine_to_avoid)
 
-    def find_available_worker_or_block(
-        self, machine_to_avoid: str = None
-    ) -> RemoteWorkerRecord:
+    def find_available_worker_or_block(self, machine_to_avoid: str = None) -> RemoteWorkerRecord:
         with self.cv:
             while not self.is_worker_available():
                 self.cv.wait()
@@ -689,7 +717,7 @@ class RemoteExecutor(BaseExecutor):
         worker = None
         while worker is None:
             worker = self.find_available_worker_or_block(avoid_machine)
-        assert worker
+        assert worker is not None
 
         # Ok, found a worker.
         bundle.worker = worker
@@ -706,9 +734,7 @@ class RemoteExecutor(BaseExecutor):
             try:
                 return self.process_work_result(bundle)
             except Exception as e:
-                logger.warning(
-                    f'{bundle}: bundle says it\'s cancelled upfront but no results?!'
-                )
+                logger.warning(f'{bundle}: bundle says it\'s cancelled upfront but no results?!')
                 self.release_worker(bundle)
                 if is_original:
                     # Weird.  We are the original owner of this
@@ -737,9 +763,7 @@ class RemoteExecutor(BaseExecutor):
         # Send input code / data to worker machine if it's not local.
         if hostname not in machine:
             try:
-                cmd = (
-                    f'{SCP} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
-                )
+                cmd = f'{SCP} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
                 start_ts = time.time()
                 logger.info(f"{bundle}: Copying work to {worker} via {cmd}.")
                 run_silently(cmd)
@@ -780,15 +804,14 @@ class RemoteExecutor(BaseExecutor):
         logger.debug(f'{bundle}: Executing {cmd} in the background to kick off work...')
         p = cmd_in_background(cmd, silent=True)
         bundle.pid = p.pid
-        logger.debug(
-            f'{bundle}: Local ssh process pid={p.pid}; remote worker is {machine}.'
-        )
+        logger.debug(f'{bundle}: Local ssh process pid={p.pid}; remote worker is {machine}.')
         return self.wait_for_process(p, bundle, 0)
 
     def wait_for_process(
-        self, p: subprocess.Popen, bundle: BundleDetails, depth: int
+        self, p: Optional[subprocess.Popen], bundle: BundleDetails, depth: int
     ) -> Any:
         machine = bundle.machine
+        assert p is not None
         pid = p.pid
         if depth > 3:
             logger.error(
@@ -806,9 +829,7 @@ class RemoteExecutor(BaseExecutor):
                 p.wait(timeout=0.25)
             except subprocess.TimeoutExpired:
                 if self.check_if_cancelled(bundle):
-                    logger.info(
-                        f'{bundle}: looks like another worker finished bundle...'
-                    )
+                    logger.info(f'{bundle}: looks like another worker finished bundle...')
                     break
             else:
                 logger.info(f"{bundle}: pid {pid} ({machine}) is finished!")
@@ -876,12 +897,9 @@ class RemoteExecutor(BaseExecutor):
                             break
 
                     run_silently(
-                        f'{SSH} {username}@{machine}'
-                        f' "/bin/rm -f {code_file} {result_file}"'
-                    )
-                    logger.debug(
-                        f'Fetching results back took {time.time() - bundle.end_ts:.1f}s.'
+                        f'{SSH} {username}@{machine} "/bin/rm -f {code_file} {result_file}"'
                     )
+                    logger.debug(f'Fetching results back took {time.time() - bundle.end_ts:.1f}s.')
                 dur = bundle.end_ts - bundle.start_ts
                 self.histogram.add_item(dur)
 
@@ -916,9 +934,7 @@ class RemoteExecutor(BaseExecutor):
             # backup.
             if bundle.backup_bundles is not None:
                 for backup in bundle.backup_bundles:
-                    logger.debug(
-                        f'{bundle}: Notifying backup {backup.uuid} that it\'s cancelled'
-                    )
+                    logger.debug(f'{bundle}: Notifying backup {backup.uuid} that it\'s cancelled')
                     backup.is_cancelled.set()
 
         # This is a backup job and, by now, we have already fetched
@@ -931,10 +947,10 @@ class RemoteExecutor(BaseExecutor):
 
             # Tell the original to stop if we finished first.
             if not was_cancelled:
-                logger.debug(
-                    f'{bundle}: Notifying original {bundle.src_bundle.uuid} we beat them to it.'
-                )
-                bundle.src_bundle.is_cancelled.set()
+                orig_bundle = bundle.src_bundle
+                assert orig_bundle is not None
+                logger.debug(f'{bundle}: Notifying original {orig_bundle.uuid} we beat them to it.')
+                orig_bundle.is_cancelled.set()
         self.release_worker(bundle, was_cancelled=was_cancelled)
         return result
 
@@ -1018,7 +1034,7 @@ class RemoteExecutor(BaseExecutor):
         # they will move the result_file to this machine and let
         # the original pick them up and unpickle them.
 
-    def emergency_retry_nasty_bundle(self, bundle: BundleDetails) -> fut.Future:
+    def emergency_retry_nasty_bundle(self, bundle: BundleDetails) -> Optional[fut.Future]:
         is_original = bundle.src_bundle is None
         bundle.worker = None
         avoid_last_machine = bundle.machine
@@ -1051,18 +1067,23 @@ class RemoteExecutor(BaseExecutor):
 
     @overrides
     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
+        if self.already_shutdown:
+            raise Exception('Submitted work after shutdown.')
         pickle = make_cloud_pickle(function, *args, **kwargs)
         bundle = self.create_original_bundle(pickle, function.__name__)
         self.total_bundles_submitted += 1
         return self._helper_executor.submit(self.launch, bundle)
 
     @overrides
-    def shutdown(self, wait=True) -> None:
-        logging.debug(f'Shutting down RemoteExecutor {self.title}')
-        self.heartbeat_stop_event.set()
-        self.heartbeat_thread.join()
-        self._helper_executor.shutdown(wait)
-        print(self.histogram)
+    def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
+        if not self.already_shutdown:
+            logging.debug(f'Shutting down RemoteExecutor {self.title}')
+            self.heartbeat_stop_event.set()
+            self.heartbeat_thread.join()
+            self._helper_executor.shutdown(wait)
+            if not quiet:
+                print(self.histogram.__repr__(label_formatter='%ds'))
+            self.already_shutdown = True
 
 
 @singleton
@@ -1075,9 +1096,7 @@ class DefaultExecutors(object):
     def ping(self, host) -> bool:
         logger.debug(f'RUN> ping -c 1 {host}')
         try:
-            x = cmd_with_timeout(
-                f'ping -c 1 {host} >/dev/null 2>/dev/null', timeout_seconds=1.0
-            )
+            x = cmd_with_timeout(f'ping -c 1 {host} >/dev/null 2>/dev/null', timeout_seconds=1.0)
             return x == 0
         except Exception:
             return False
@@ -1102,7 +1121,7 @@ class DefaultExecutors(object):
                     RemoteWorkerRecord(
                         username='scott',
                         machine='cheetah.house',
-                        weight=34,
+                        weight=30,
                         count=6,
                     ),
                 )
@@ -1132,7 +1151,7 @@ class DefaultExecutors(object):
                     RemoteWorkerRecord(
                         username='scott',
                         machine='puma.cabin',
-                        weight=25,
+                        weight=30,
                         count=6,
                     ),
                 )
@@ -1142,7 +1161,7 @@ class DefaultExecutors(object):
                     RemoteWorkerRecord(
                         username='scott',
                         machine='backup.house',
-                        weight=7,
+                        weight=8,
                         count=2,
                     ),
                 )
@@ -1160,11 +1179,11 @@ class DefaultExecutors(object):
 
     def shutdown(self) -> None:
         if self.thread_executor is not None:
-            self.thread_executor.shutdown()
+            self.thread_executor.shutdown(wait=True, quiet=True)
             self.thread_executor = None
         if self.process_executor is not None:
-            self.process_executor.shutdown()
+            self.process_executor.shutdown(wait=True, quiet=True)
             self.process_executor = None
         if self.remote_executor is not None:
-            self.remote_executor.shutdown()
+            self.remote_executor.shutdown(wait=True, quiet=True)
             self.remote_executor = None
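
With the keyword-only shutdown() signature, call sites now read like the
DefaultExecutors.shutdown() hunk above: wait for outstanding work and
optionally suppress the histogram dump. A small sketch of the intended call
pattern (make_executor() is a placeholder factory for however the concrete
BaseExecutor is actually built):

    def run_jobs(make_executor, jobs) -> list:
        executor = make_executor()
        try:
            futures = [executor.submit(job) for job in jobs]
            return [f.result() for f in futures]
        finally:
            # Wait for in-flight work but skip printing the latency histogram.
            executor.shutdown(wait=True, quiet=True)
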