More cleanup.
[python_utils.git] / executors.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3
4 """Defines three executors: a thread executor for doing work using a
5 threadpool, a process executor for doing work in other processes on
6 the same machine and a remote executor for farming out work to other
7 machines.
8
9 Also defines DefaultExecutors, which is a container for references to
10 global executors / worker pools with automatic shutdown semantics."""
11
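# A minimal usage sketch of the thread executor defined below (an editorial
# example, not part of the module; it assumes the library's commandline config
# has already been parsed, as the rest of this codebase expects):
#
#     def work(x: int) -> int:
#         return x * x
#
#     pool = ThreadExecutor()
#     future = pool.submit(work, 10)        # a concurrent.futures.Future
#     assert future.result() == 100
#     pool.shutdown(wait=True, quiet=True)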
12 from __future__ import annotations
13
14 import concurrent.futures as fut
15 import logging
16 import os
17 import platform
18 import random
19 import subprocess
20 import threading
21 import time
22 import warnings
23 from abc import ABC, abstractmethod
24 from collections import defaultdict
25 from dataclasses import dataclass
26 from typing import Any, Callable, Dict, List, Optional, Set
27
28 import cloudpickle  # type: ignore
29 import numpy
30 from overrides import overrides
31
32 import argparse_utils
33 import config
34 import histogram as hist
35 from ansi import bg, fg, reset, underline
36 from decorator_utils import singleton
37 from exec_utils import cmd_in_background, cmd_with_timeout, run_silently
38 from thread_utils import background_thread
39
40 logger = logging.getLogger(__name__)
41
42 parser = config.add_commandline_args(
43     f"Executors ({__file__})", "Args related to processing executors."
44 )
45 parser.add_argument(
46     '--executors_threadpool_size',
47     type=int,
48     metavar='#THREADS',
49     help='Number of threads in the default threadpool, leave unset for default',
50     default=None,
51 )
52 parser.add_argument(
53     '--executors_processpool_size',
54     type=int,
55     metavar='#PROCESSES',
56     help='Number of processes in the default processpool, leave unset for default',
57     default=None,
58 )
59 parser.add_argument(
60     '--executors_schedule_remote_backups',
61     default=True,
62     action=argparse_utils.ActionNoYes,
63     help='Should we schedule duplicative backup work if a remote bundle is slow',
64 )
65 parser.add_argument(
66     '--executors_max_bundle_failures',
67     type=int,
68     default=3,
69     metavar='#FAILURES',
70     help='Maximum number of failures before giving up on a bundle',
71 )
72
73 SSH = '/usr/bin/ssh -oForwardX11=no'
74 SCP = '/usr/bin/scp -C'
75
76
77 def make_cloud_pickle(fun, *args, **kwargs):
78     logger.debug("Making cloudpickled bundle at %s", fun.__name__)
79     return cloudpickle.dumps((fun, args, kwargs))
80
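# For reference, a bundle produced by make_cloud_pickle round-trips like this
# (illustrative values; the unpacking mirrors ProcessExecutor.run_cloud_pickle
# defined later in this file):
#
#     payload = make_cloud_pickle(max, 3, 7)
#     fun, args, kwargs = cloudpickle.loads(payload)
#     assert fun(*args, **kwargs) == 7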
81
82 class BaseExecutor(ABC):
83     """The base executor interface definition."""
84
85     def __init__(self, *, title=''):
86         self.title = title
87         self.histogram = hist.SimpleHistogram(
88             hist.SimpleHistogram.n_evenly_spaced_buckets(int(0), int(500), 50)
89         )
90         self.task_count = 0
91
92     @abstractmethod
93     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
94         pass
95
96     @abstractmethod
97     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
98         pass
99
100     def shutdown_if_idle(self, *, quiet: bool = False) -> bool:
101         """Shutdown the executor and return True if the executor is idle
102         (i.e. there are no pending or active tasks).  Return False
103         otherwise.  Note: this should only be called by the launcher
104         process.
105
106         """
107         if self.task_count == 0:
108             self.shutdown(wait=True, quiet=quiet)
109             return True
110         return False
111
112     def adjust_task_count(self, delta: int) -> None:
113         """Change the task count.  Note: do not call this method from a
114         worker; it should only be called by the launcher process /
115         thread / machine.
116
117         """
118         self.task_count += delta
119         logger.debug('Adjusted task count by %d to %d.', delta, self.task_count)
120
121     def get_task_count(self) -> int:
122         """Change the task count.  Note: do not call this method from a
123         worker, it should only be called by the launcher process /
124         thread / machine.
125
126         """
127         return self.task_count
128
129
130 class ThreadExecutor(BaseExecutor):
131     """A threadpool executor instance."""
132
133     def __init__(self, max_workers: Optional[int] = None):
134         super().__init__()
135         workers = None
136         if max_workers is not None:
137             workers = max_workers
138         elif 'executors_threadpool_size' in config.config:
139             workers = config.config['executors_threadpool_size']
140         logger.debug('Creating threadpool executor with %s workers', workers)
141         self._thread_pool_executor = fut.ThreadPoolExecutor(
142             max_workers=workers, thread_name_prefix="thread_executor_helper"
143         )
144         self.already_shutdown = False
145
146     # This is run on a different thread; do not adjust task count here.
147     @staticmethod
148     def run_local_bundle(fun, *args, **kwargs):
149         logger.debug("Running local bundle at %s", fun.__name__)
150         result = fun(*args, **kwargs)
151         return result
152
153     @overrides
154     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
155         if self.already_shutdown:
156             raise Exception('Submitted work after shutdown.')
157         self.adjust_task_count(+1)
158         newargs = []
159         newargs.append(function)
160         for arg in args:
161             newargs.append(arg)
162         start = time.time()
163         result = self._thread_pool_executor.submit(
164             ThreadExecutor.run_local_bundle, *newargs, **kwargs
165         )
166         result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
167         result.add_done_callback(lambda _: self.adjust_task_count(-1))
168         return result
169
170     @overrides
171     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
172         if not self.already_shutdown:
173             logger.debug('Shutting down threadpool executor %s', self.title)
174             self._thread_pool_executor.shutdown(wait)
175             if not quiet:
176                 print(self.histogram.__repr__(label_formatter='%ds'))
177             self.already_shutdown = True
178
179
180 class ProcessExecutor(BaseExecutor):
181     """A processpool executor."""
182
183     def __init__(self, max_workers=None):
184         super().__init__()
185         workers = None
186         if max_workers is not None:
187             workers = max_workers
188         elif 'executors_processpool_size' in config.config:
189             workers = config.config['executors_processpool_size']
190         logger.debug('Creating processpool executor with %s workers.', workers)
191         self._process_executor = fut.ProcessPoolExecutor(
192             max_workers=workers,
193         )
194         self.already_shutdown = False
195
196     # This is run in another process; do not adjust task count here.
197     @staticmethod
198     def run_cloud_pickle(pickle):
199         fun, args, kwargs = cloudpickle.loads(pickle)
200         logger.debug("Running pickled bundle at %s", fun.__name__)
201         result = fun(*args, **kwargs)
202         return result
203
204     @overrides
205     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
206         if self.already_shutdown:
207             raise Exception('Submitted work after shutdown.')
208         start = time.time()
209         self.adjust_task_count(+1)
210         pickle = make_cloud_pickle(function, *args, **kwargs)
211         result = self._process_executor.submit(ProcessExecutor.run_cloud_pickle, pickle)
212         result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
213         result.add_done_callback(lambda _: self.adjust_task_count(-1))
214         return result
215
216     @overrides
217     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
218         if not self.already_shutdown:
219             logger.debug('Shutting down processpool executor %s', self.title)
220             self._process_executor.shutdown(wait)
221             if not quiet:
222                 print(self.histogram.__repr__(label_formatter='%ds'))
223             self.already_shutdown = True
224
225     def __getstate__(self):
226         state = self.__dict__.copy()
227         state['_process_executor'] = None
228         return state
229
230
231 class RemoteExecutorException(Exception):
232     """Thrown when a bundle cannot be executed despite several retries."""
233
234     pass
235
236
237 @dataclass
238 class RemoteWorkerRecord:
239     """A record of info about a remote worker."""
240
241     username: str
242     machine: str
243     weight: int
244     count: int
245
246     def __hash__(self):
247         return hash((self.username, self.machine))
248
249     def __repr__(self):
250         return f'{self.username}@{self.machine}'
251
252
253 @dataclass
254 class BundleDetails:
255     """All info necessary to define some unit of work that needs to be
256     done, where it is being run, its state, whether it is an original
257     bundle or a backup bundle, how many times it has failed, etc.
258
259     """
260
261     pickled_code: bytes
262     uuid: str
263     fname: str
264     worker: Optional[RemoteWorkerRecord]
265     username: Optional[str]
266     machine: Optional[str]
267     hostname: str
268     code_file: str
269     result_file: str
270     pid: int
271     start_ts: float
272     end_ts: float
273     slower_than_local_p95: bool
274     slower_than_global_p95: bool
275     src_bundle: Optional[BundleDetails]
276     is_cancelled: threading.Event
277     was_cancelled: bool
278     backup_bundles: Optional[List[BundleDetails]]
279     failure_count: int
280
281     def __repr__(self):
282         uuid = self.uuid
283         if uuid[-9:-2] == '_backup':
284             uuid = uuid[:-9]
285             suffix = f'{uuid[-6:]}_b{self.uuid[-1:]}'
286         else:
287             suffix = uuid[-6:]
288
289         colorz = [
290             fg('violet red'),
291             fg('red'),
292             fg('orange'),
293             fg('peach orange'),
294             fg('yellow'),
295             fg('marigold yellow'),
296             fg('green yellow'),
297             fg('tea green'),
298             fg('cornflower blue'),
299             fg('turquoise blue'),
300             fg('tropical blue'),
301             fg('lavender purple'),
302             fg('medium purple'),
303         ]
304         c = colorz[int(uuid[-2:], 16) % len(colorz)]
305         fname = self.fname if self.fname is not None else 'nofname'
306         machine = self.machine if self.machine is not None else 'nomachine'
307         return f'{c}{suffix}/{fname}/{machine}{reset()}'
308
309
310 class RemoteExecutorStatus:
311     """A status 'scoreboard' for a remote executor."""
312
313     def __init__(self, total_worker_count: int) -> None:
314         self.worker_count: int = total_worker_count
315         self.known_workers: Set[RemoteWorkerRecord] = set()
316         self.start_time: float = time.time()
317         self.start_per_bundle: Dict[str, Optional[float]] = defaultdict(float)
318         self.end_per_bundle: Dict[str, float] = defaultdict(float)
319         self.finished_bundle_timings_per_worker: Dict[RemoteWorkerRecord, List[float]] = {}
320         self.in_flight_bundles_by_worker: Dict[RemoteWorkerRecord, Set[str]] = {}
321         self.bundle_details_by_uuid: Dict[str, BundleDetails] = {}
322         self.finished_bundle_timings: List[float] = []
323         self.last_periodic_dump: Optional[float] = None
324         self.total_bundles_submitted: int = 0
325
326         # Protects reads and modifications of self.  Also used
327         # as a memory fence for modifications to bundles.
328         self.lock: threading.Lock = threading.Lock()
329
330     def record_acquire_worker(self, worker: RemoteWorkerRecord, uuid: str) -> None:
331         with self.lock:
332             self.record_acquire_worker_already_locked(worker, uuid)
333
334     def record_acquire_worker_already_locked(self, worker: RemoteWorkerRecord, uuid: str) -> None:
335         assert self.lock.locked()
336         self.known_workers.add(worker)
337         self.start_per_bundle[uuid] = None
338         x = self.in_flight_bundles_by_worker.get(worker, set())
339         x.add(uuid)
340         self.in_flight_bundles_by_worker[worker] = x
341
342     def record_bundle_details(self, details: BundleDetails) -> None:
343         with self.lock:
344             self.record_bundle_details_already_locked(details)
345
346     def record_bundle_details_already_locked(self, details: BundleDetails) -> None:
347         assert self.lock.locked()
348         self.bundle_details_by_uuid[details.uuid] = details
349
350     def record_release_worker(
351         self,
352         worker: RemoteWorkerRecord,
353         uuid: str,
354         was_cancelled: bool,
355     ) -> None:
356         with self.lock:
357             self.record_release_worker_already_locked(worker, uuid, was_cancelled)
358
359     def record_release_worker_already_locked(
360         self,
361         worker: RemoteWorkerRecord,
362         uuid: str,
363         was_cancelled: bool,
364     ) -> None:
365         assert self.lock.locked()
366         ts = time.time()
367         self.end_per_bundle[uuid] = ts
368         self.in_flight_bundles_by_worker[worker].remove(uuid)
369         if not was_cancelled:
370             start = self.start_per_bundle[uuid]
371             assert start is not None
372             bundle_latency = ts - start
373             x = self.finished_bundle_timings_per_worker.get(worker, [])
374             x.append(bundle_latency)
375             self.finished_bundle_timings_per_worker[worker] = x
376             self.finished_bundle_timings.append(bundle_latency)
377
378     def record_processing_began(self, uuid: str):
379         with self.lock:
380             self.start_per_bundle[uuid] = time.time()
381
382     def total_in_flight(self) -> int:
383         assert self.lock.locked()
384         total_in_flight = 0
385         for worker in self.known_workers:
386             total_in_flight += len(self.in_flight_bundles_by_worker[worker])
387         return total_in_flight
388
389     def total_idle(self) -> int:
390         assert self.lock.locked()
391         return self.worker_count - self.total_in_flight()
392
393     def __repr__(self):
394         assert self.lock.locked()
395         ts = time.time()
396         total_finished = len(self.finished_bundle_timings)
397         total_in_flight = self.total_in_flight()
398         ret = f'\n\n{underline()}Remote Executor Pool Status{reset()}: '
399         qall = None
400         if len(self.finished_bundle_timings) > 1:
401             qall = numpy.quantile(self.finished_bundle_timings, [0.5, 0.95])
402             ret += (
403                 f'⏱=∀p50:{qall[0]:.1f}s, ∀p95:{qall[1]:.1f}s, total={ts-self.start_time:.1f}s, '
404                 f'✅={total_finished}/{self.total_bundles_submitted}, '
405                 f'💻n={total_in_flight}/{self.worker_count}\n'
406             )
407         else:
408             ret += (
409                 f'⏱={ts-self.start_time:.1f}s, '
410                 f'✅={total_finished}/{self.total_bundles_submitted}, '
411                 f'💻n={total_in_flight}/{self.worker_count}\n'
412             )
413
414         for worker in self.known_workers:
415             ret += f'  {fg("lightning yellow")}{worker.machine}{reset()}: '
416             timings = self.finished_bundle_timings_per_worker.get(worker, [])
417             count = len(timings)
418             qworker = None
419             if count > 1:
420                 qworker = numpy.quantile(timings, [0.5, 0.95])
421                 ret += f' 💻p50: {qworker[0]:.1f}s, 💻p95: {qworker[1]:.1f}s\n'
422             else:
423                 ret += '\n'
424             if count > 0:
425                 ret += f'    ...finished {count} total bundle(s) so far\n'
426             in_flight = len(self.in_flight_bundles_by_worker[worker])
427             if in_flight > 0:
428                 ret += f'    ...{in_flight} bundles currently in flight:\n'
429                 for bundle_uuid in self.in_flight_bundles_by_worker[worker]:
430                     details = self.bundle_details_by_uuid.get(bundle_uuid, None)
431                     pid = str(details.pid) if (details and details.pid != 0) else "TBD"
432                     if self.start_per_bundle[bundle_uuid] is not None:
433                         sec = ts - self.start_per_bundle[bundle_uuid]
434                         ret += f'       (pid={pid}): {details} for {sec:.1f}s so far '
435                     else:
436                         ret += f'       {details} setting up / copying data...'
437                         sec = 0.0
438
439                     if qworker is not None:
440                         if sec > qworker[1]:
441                             ret += f'{bg("red")}>💻p95{reset()} '
442                             if details is not None:
443                                 details.slower_than_local_p95 = True
444                         else:
445                             if details is not None:
446                                 details.slower_than_local_p95 = False
447
448                     if qall is not None:
449                         if sec > qall[1]:
450                             ret += f'{bg("red")}>∀p95{reset()} '
451                             if details is not None:
452                                 details.slower_than_global_p95 = True
453                         elif details is not None:
454                             details.slower_than_global_p95 = False
455                     ret += '\n'
456         return ret
457
458     def periodic_dump(self, total_bundles_submitted: int) -> None:
459         assert self.lock.locked()
460         self.total_bundles_submitted = total_bundles_submitted
461         ts = time.time()
462         if self.last_periodic_dump is None or ts - self.last_periodic_dump > 5.0:
463             print(self)
464             self.last_periodic_dump = ts
465
466
467 class RemoteWorkerSelectionPolicy(ABC):
468     """A policy for selecting a remote worker base class."""
469
470     def __init__(self):
471         self.workers: Optional[List[RemoteWorkerRecord]] = None
472
473     def register_worker_pool(self, workers: List[RemoteWorkerRecord]):
474         self.workers = workers
475
476     @abstractmethod
477     def is_worker_available(self) -> bool:
478         pass
479
480     @abstractmethod
481     def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
482         pass
483
484
485 class WeightedRandomRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
486     """A remote worker selector that uses weighted RNG."""
487
488     @overrides
489     def is_worker_available(self) -> bool:
490         if self.workers:
491             for worker in self.workers:
492                 if worker.count > 0:
493                     return True
494         return False
495
496     @overrides
497     def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
498         grabbag = []
499         if self.workers:
500             for worker in self.workers:
501                 if worker.machine != machine_to_avoid:
502                     if worker.count > 0:
503                         for _ in range(worker.count * worker.weight):
504                             grabbag.append(worker)
505
506         if len(grabbag) == 0:
507             logger.debug('There are no available workers that avoid %s', machine_to_avoid)
508             if self.workers:
509                 for worker in self.workers:
510                     if worker.count > 0:
511                         for _ in range(worker.count * worker.weight):
512                             grabbag.append(worker)
513
514         if len(grabbag) == 0:
515             logger.warning('There are no available workers?!')
516             return None
517
518         worker = random.sample(grabbag, 1)[0]
519         assert worker.count > 0
520         worker.count -= 1
521         logger.debug('Selected worker %s', worker)
522         return worker
523
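# A hedged, standalone sketch of the weighted draw above, using hypothetical
# (machine, count, weight) tuples rather than RemoteWorkerRecord: each worker
# appears count * weight times in the grab bag, so its selection probability is
# proportional to remaining capacity times weight.
#
#     import random
#
#     workers = [('cheetah.house', 6, 30), ('meerkat.cabin', 2, 12)]
#     grabbag = [m for (m, count, weight) in workers for _ in range(count * weight)]
#     # cheetah.house appears 180 times and meerkat.cabin 24 times, so the draw
#     # below picks cheetah.house ~88% of the time (until its count drains).
#     choice = random.choice(grabbag)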
524
525 class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
526     """A remote worker selector that just round robins."""
527
528     def __init__(self) -> None:
529         super().__init__()
530         self.index = 0
531
532     @overrides
533     def is_worker_available(self) -> bool:
534         if self.workers:
535             for worker in self.workers:
536                 if worker.count > 0:
537                     return True
538         return False
539
540     @overrides
541     def acquire_worker(self, machine_to_avoid: Optional[str] = None) -> Optional[RemoteWorkerRecord]:
542         if self.workers:
543             x = self.index
544             while True:
545                 worker = self.workers[x]
546                 if worker.count > 0:
547                     worker.count -= 1
548                     x += 1
549                     if x >= len(self.workers):
550                         x = 0
551                     self.index = x
552                     logger.debug('Selected worker %s', worker)
553                     return worker
554                 x += 1
555                 if x >= len(self.workers):
556                     x = 0
557                 if x == self.index:
558                     logger.warning('Unexpectedly could not find an available worker; giving up.')
559                     return None
560         return None
561
562
563 class RemoteExecutor(BaseExecutor):
564     """A remote work executor."""
565
566     def __init__(
567         self,
568         workers: List[RemoteWorkerRecord],
569         policy: RemoteWorkerSelectionPolicy,
570     ) -> None:
571         super().__init__()
572         self.workers = workers
573         self.policy = policy
574         self.worker_count = 0
575         for worker in self.workers:
576             self.worker_count += worker.count
577         if self.worker_count <= 0:
578             msg = f"We need somewhere to schedule work; count was {self.worker_count}"
579             logger.critical(msg)
580             raise RemoteExecutorException(msg)
581         self.policy.register_worker_pool(self.workers)
582         self.cv = threading.Condition()
583         logger.debug('Creating %d local threads, one per remote worker.', self.worker_count)
584         self._helper_executor = fut.ThreadPoolExecutor(
585             thread_name_prefix="remote_executor_helper",
586             max_workers=self.worker_count,
587         )
588         self.status = RemoteExecutorStatus(self.worker_count)
589         self.total_bundles_submitted = 0
590         self.backup_lock = threading.Lock()
591         self.last_backup = None
592         (
593             self.heartbeat_thread,
594             self.heartbeat_stop_event,
595         ) = self.run_periodic_heartbeat()
596         self.already_shutdown = False
597
598     @background_thread
599     def run_periodic_heartbeat(self, stop_event: threading.Event) -> None:
600         while not stop_event.is_set():
601             time.sleep(5.0)
602             logger.debug('Running periodic heartbeat code...')
603             self.heartbeat()
604         logger.debug('Periodic heartbeat thread shutting down.')
605
606     def heartbeat(self) -> None:
607         # Note: this is invoked on a background thread, not an
608         # executor thread.  Be careful what you do with it b/c it
609         # needs to get back and dump status again periodically.
610         with self.status.lock:
611             self.status.periodic_dump(self.total_bundles_submitted)
612
613             # Look for bundles to reschedule via executor.submit
614             if config.config['executors_schedule_remote_backups']:
615                 self.maybe_schedule_backup_bundles()
616
617     def maybe_schedule_backup_bundles(self):
618         assert self.status.lock.locked()
619         num_done = len(self.status.finished_bundle_timings)
620         num_idle_workers = self.worker_count - self.task_count
621         now = time.time()
622         if (
623             num_done > 2
624             and num_idle_workers > 1
625             and (self.last_backup is None or (now - self.last_backup > 9.0))
626             and self.backup_lock.acquire(blocking=False)
627         ):
628             try:
629                 assert self.backup_lock.locked()
630
631                 bundle_to_backup = None
632                 best_score = None
633                 for (
634                     worker,
635                     bundle_uuids,
636                 ) in self.status.in_flight_bundles_by_worker.items():
637
638                     # Prefer to schedule backups of bundles running on
639                     # slower machines.
640                     base_score = 0
641                     for record in self.workers:
642                         if worker.machine == record.machine:
643                             base_score = float(record.weight)
644                             base_score = 1.0 / base_score
645                             base_score *= 200.0
646                             base_score = int(base_score)
647                             break
648
649                     for uuid in bundle_uuids:
650                         bundle = self.status.bundle_details_by_uuid.get(uuid, None)
651                         if (
652                             bundle is not None
653                             and bundle.src_bundle is None
654                             and bundle.backup_bundles is not None
655                         ):
656                             score = base_score
657
658                             # Schedule backups of bundles running
659                             # longer; especially those that are
660                             # unexpectedly slow.
661                             start_ts = self.status.start_per_bundle[uuid]
662                             if start_ts is not None:
663                                 runtime = now - start_ts
664                                 score += runtime
665                                 logger.debug('score[%s] => %.1f  # latency boost', bundle, score)
666
667                                 if bundle.slower_than_local_p95:
668                                     score += runtime / 2
669                                     logger.debug('score[%s] => %.1f  # >worker p95', bundle, score)
670
671                                 if bundle.slower_than_global_p95:
672                                     score += runtime / 4
673                                     logger.debug('score[%s] => %.1f  # >global p95', bundle, score)
674
675                             # Prefer backups of bundles that don't
676                             # have backups already.
677                             backup_count = len(bundle.backup_bundles)
678                             if backup_count == 0:
679                                 score *= 2
680                             elif backup_count == 1:
681                                 score /= 2
682                             elif backup_count == 2:
683                                 score /= 8
684                             else:
685                                 score = 0
686                             logger.debug(
687                                 'score[%s] => %.1f  # %d dup backup factor',
688                                 bundle,
689                                 score, backup_count,
690                             )
691
692                             if score != 0 and (best_score is None or score > best_score):
693                                 bundle_to_backup = bundle
694                                 assert bundle is not None
695                                 assert bundle.backup_bundles is not None
696                                 assert bundle.src_bundle is None
697                                 best_score = score
698
699                 # Note: this is all still happening on the heartbeat
700                 # runner thread.  That's ok because
701                 # schedule_backup_for_bundle uses the executor to
702                 # submit the bundle again which will cause it to be
703                 # picked up by a worker thread and allow this thread
704                 # to return to run future heartbeats.
705                 if bundle_to_backup is not None:
706                     self.last_backup = now
707                     logger.info(
708                         '=====> SCHEDULING BACKUP %s (score=%.1f) <=====',
709                         bundle_to_backup,
710                         best_score,
711                     )
712                     self.schedule_backup_for_bundle(bundle_to_backup)
713             finally:
714                 self.backup_lock.release()
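        # Worked example of the scoring heuristic above (illustrative numbers,
        # not measured data): a bundle on a weight=8 machine starts with
        # base_score = int(200 / 8) = 25.  If it has been running for 30s and
        # is slower than both its worker's p95 and the global p95, its score
        # grows to 25 + 30 + 30/2 + 30/4 = 77.5; having no backups yet doubles
        # that to 155, making it the likely pick for a duplicate submission.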
715
716     def is_worker_available(self) -> bool:
717         return self.policy.is_worker_available()
718
719     def acquire_worker(self, machine_to_avoid: Optional[str] = None) -> Optional[RemoteWorkerRecord]:
720         return self.policy.acquire_worker(machine_to_avoid)
721
722     def find_available_worker_or_block(self, machine_to_avoid: Optional[str] = None) -> RemoteWorkerRecord:
723         with self.cv:
724             while not self.is_worker_available():
725                 self.cv.wait()
726             worker = self.acquire_worker(machine_to_avoid)
727             if worker is not None:
728                 return worker
729         msg = "We should never reach this point in the code"
730         logger.critical(msg)
731         raise Exception(msg)
732
733     def release_worker(self, bundle: BundleDetails, *, was_cancelled=True) -> None:
734         worker = bundle.worker
735         assert worker is not None
736         logger.debug('Released worker %s', worker)
737         self.status.record_release_worker(
738             worker,
739             bundle.uuid,
740             was_cancelled,
741         )
742         with self.cv:
743             worker.count += 1
744             self.cv.notify()
745         self.adjust_task_count(-1)
746
747     def check_if_cancelled(self, bundle: BundleDetails) -> bool:
748         with self.status.lock:
749             if bundle.is_cancelled.wait(timeout=0.0):
750                 logger.debug('Bundle %s is cancelled, bail out.', bundle.uuid)
751                 bundle.was_cancelled = True
752                 return True
753         return False
754
755     def launch(self, bundle: BundleDetails, override_avoid_machine=None) -> Any:
756         """Find a worker for bundle or block until one is available."""
757         self.adjust_task_count(+1)
758         uuid = bundle.uuid
759         hostname = bundle.hostname
760         avoid_machine = override_avoid_machine
761         is_original = bundle.src_bundle is None
762
763         # Try not to schedule a backup on the same host as the original.
764         if avoid_machine is None and bundle.src_bundle is not None:
765             avoid_machine = bundle.src_bundle.machine
766         worker = None
767         while worker is None:
768             worker = self.find_available_worker_or_block(avoid_machine)
769         assert worker is not None
770
771         # Ok, found a worker.
772         bundle.worker = worker
773         machine = bundle.machine = worker.machine
774         username = bundle.username = worker.username
775         self.status.record_acquire_worker(worker, uuid)
776         logger.debug('%s: Running bundle on %s...', bundle, worker)
777
778         # Before we do any work, make sure the bundle is still viable.
779         # It may have been some time between when it was submitted and
780         # now due to lack of worker availability and someone else may
781         # have already finished it.
782         if self.check_if_cancelled(bundle):
783             try:
784                 return self.process_work_result(bundle)
785             except Exception as e:
786                 logger.warning('%s: bundle says it\'s cancelled upfront but no results?!', bundle)
787                 self.release_worker(bundle)
788                 if is_original:
789                     # Weird.  We are the original owner of this
790                     # bundle.  For it to have been cancelled, a backup
791                     # must have already started and completed before
792                     # we even got started.  Moreover, the backup says
793                     # it is done but we can't find the results it
794                     # should have copied over.  Reschedule the whole
795                     # thing.
796                     logger.exception(e)
797                     logger.error(
798                         '%s: We are the original owner thread and yet there are '
799                         'no results for this bundle.  This is unexpected and bad.',
800                         bundle,
801                     )
802                     return self.emergency_retry_nasty_bundle(bundle)
803                 else:
804                     # Expected(?).  We're a backup and our bundle is
805                     # cancelled before we even got started.  Something
806                     # went bad in process_work_result (I actually don't
807                     # see what?) but probably not worth worrying
808                     # about.  Let the original thread worry about
809                     # either finding the results or complaining about
810                     # it.
811                     return None
812
813         # Send input code / data to worker machine if it's not local.
814         if hostname not in machine:
815             try:
816                 cmd = f'{SCP} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
817                 start_ts = time.time()
818                 logger.info("%s: Copying work to %s via %s.", bundle, worker, cmd)
819                 run_silently(cmd)
820                 xfer_latency = time.time() - start_ts
821                 logger.debug("%s: Copying to %s took %.1fs.", bundle, worker, xfer_latency)
822             except Exception as e:
823                 self.release_worker(bundle)
824                 if is_original:
825                     # Weird.  We tried to copy the code to the worker and it failed...
826                     # And we're the original bundle.  We have to retry.
827                     logger.exception(e)
828                     logger.error(
829                         "%s: Failed to send instructions to the worker machine?! "
830                         "This is not expected; we\'re the original bundle so this shouldn\'t "
831                         "be a race condition.  Attempting an emergency retry...",
832                         bundle,
833                     )
834                     return self.emergency_retry_nasty_bundle(bundle)
835                 else:
836                     # This is actually expected; we're a backup.
837                     # There's a race condition where someone else
838                     # already finished the work and removed the source
839                     # code file before we could copy it.  No biggie.
840                     logger.warning(
841                         '%s: Failed to send instructions to the worker machine... '
842                         'We\'re a backup and this may be caused by the original (or '
843                         'some other backup) already finishing this work.  Ignoring.',
844                         bundle,
845                     )
846                     return None
847
848         # Kick off the work.  Note that if this fails we let
849         # wait_for_process deal with it.
850         self.status.record_processing_began(uuid)
851         cmd = (
852             f'{SSH} {bundle.username}@{bundle.machine} '
853             f'"source py38-venv/bin/activate &&'
854             f' /home/scott/lib/python_modules/remote_worker.py'
855             f' --code_file {bundle.code_file} --result_file {bundle.result_file}"'
856         )
857         logger.debug('%s: Executing %s in the background to kick off work...', bundle, cmd)
858         p = cmd_in_background(cmd, silent=True)
859         bundle.pid = p.pid
860         logger.debug('%s: Local ssh process pid=%d; remote worker is %s.', bundle, p.pid, machine)
861         return self.wait_for_process(p, bundle, 0)
862
863     def wait_for_process(
864         self, p: Optional[subprocess.Popen], bundle: BundleDetails, depth: int
865     ) -> Any:
866         machine = bundle.machine
867         assert p is not None
868         pid = p.pid
869         if depth > 3:
870             logger.error(
871                 "I've gotten repeated errors waiting on this bundle; giving up on pid=%d", pid
872             )
873             p.terminate()
874             self.release_worker(bundle)
875             return self.emergency_retry_nasty_bundle(bundle)
876
877         # Spin until either the ssh job we scheduled finishes the
878         # bundle or some backup worker signals that they finished it
879         # before we could.
880         while True:
881             try:
882                 p.wait(timeout=0.25)
883             except subprocess.TimeoutExpired:
884                 if self.check_if_cancelled(bundle):
885                     logger.info('%s: looks like another worker finished bundle...', bundle)
886                     break
887             else:
888                 logger.info("%s: pid %d (%s) is finished!", bundle, pid, machine)
889                 p = None
890                 break
891
892         # If we get here we believe the bundle is done; either the ssh
893         # subprocess finished (hopefully successfully) or we noticed
894         # that some other worker seems to have completed the bundle
895         # and we're bailing out.
896         try:
897             ret = self.process_work_result(bundle)
898             if ret is not None and p is not None:
899                 p.terminate()
900             return ret
901
902         # Something went wrong; e.g. we could not copy the results
903         # back, cleanup after ourselves on the remote machine, or
904         # unpickle the results we got from the remote machine.  If we
905         # still have an active ssh subprocess, keep waiting on it.
906         # Otherwise, time for an emergency reschedule.
907         except Exception as e:
908             logger.exception(e)
909             logger.error('%s: Something unexpected just happened...', bundle)
910             if p is not None:
911                 logger.warning(
912                     "%s: Failed to wrap up \"done\" bundle, re-waiting on active ssh.", bundle
913                 )
914                 return self.wait_for_process(p, bundle, depth + 1)
915             else:
916                 self.release_worker(bundle)
917                 return self.emergency_retry_nasty_bundle(bundle)
918
919     def process_work_result(self, bundle: BundleDetails) -> Any:
920         with self.status.lock:
921             is_original = bundle.src_bundle is None
922             was_cancelled = bundle.was_cancelled
923             username = bundle.username
924             machine = bundle.machine
925             result_file = bundle.result_file
926             code_file = bundle.code_file
927
928             # Whether original or backup, if we finished first we must
929             # fetch the results if the computation happened on a
930             # remote machine.
931             bundle.end_ts = time.time()
932             if not was_cancelled:
933                 assert bundle.machine is not None
934                 if bundle.hostname not in bundle.machine:
935                     cmd = f'{SCP} {username}@{machine}:{result_file} {result_file} 2>/dev/null'
936                     logger.info(
937                         "%s: Fetching results back from %s@%s via %s",
938                         bundle,
939                         username,
940                         machine,
941                         cmd,
942                     )
943
944                     # If either of these throw they are handled in
945                     # wait_for_process.
946                     attempts = 0
947                     while True:
948                         try:
949                             run_silently(cmd)
950                         except Exception as e:
951                             attempts += 1
952                             if attempts >= 3:
953                                 raise e
954                         else:
955                             break
956
957                     run_silently(
958                         f'{SSH} {username}@{machine}' f' "/bin/rm -f {code_file} {result_file}"'
959                     )
960                     logger.debug('Fetching results back took %.2fs', time.time() - bundle.end_ts)
961                 dur = bundle.end_ts - bundle.start_ts
962                 self.histogram.add_item(dur)
963
964         # Only the original worker should unpickle the file contents
965         # though since it's the only one whose result matters.  The
966         # original is also the only job that may delete result_file
967         # from disk.  Note that the original may have been cancelled
968         # if one of the backups finished first; it still must read the
969         # result from disk.
970         if is_original:
971             logger.debug("%s: Unpickling %s.", bundle, result_file)
972             try:
973                 with open(result_file, 'rb') as rb:
974                     serialized = rb.read()
975                 result = cloudpickle.loads(serialized)
976             except Exception as e:
977                 logger.exception(e)
978                 logger.error('Failed to load %s... this is bad news.', result_file)
979                 self.release_worker(bundle)
980
981                 # Re-raise the exception; the code in wait_for_process may
982                 # decide to emergency_retry_nasty_bundle here.
983                 raise e
984             logger.debug('Removing local (master) %s and %s.', code_file, result_file)
985             os.remove(result_file)
986             os.remove(code_file)
987
988             # Notify any backups that the original is done so they
989             # should stop ASAP.  Do this whether or not we
990             # finished first since there could be more than one
991             # backup.
992             if bundle.backup_bundles is not None:
993                 for backup in bundle.backup_bundles:
994                     logger.debug(
995                         '%s: Notifying backup %s that it\'s cancelled', bundle, backup.uuid
996                     )
997                     backup.is_cancelled.set()
998
999         # This is a backup job and, by now, we have already fetched
1000         # the bundle results.
1001         else:
1002             # Backup results don't matter, they just need to leave the
1003             # result file in the right place for their originals to
1004             # read/unpickle later.
1005             result = None
1006
1007             # Tell the original to stop if we finished first.
1008             if not was_cancelled:
1009                 orig_bundle = bundle.src_bundle
1010                 assert orig_bundle is not None
1011                 logger.debug(
1012                     '%s: Notifying original %s we beat them to it.', bundle, orig_bundle.uuid
1013                 )
1014                 orig_bundle.is_cancelled.set()
1015         self.release_worker(bundle, was_cancelled=was_cancelled)
1016         return result
1017
1018     def create_original_bundle(self, pickle, fname: str):
1019         from string_utils import generate_uuid
1020
1021         uuid = generate_uuid(omit_dashes=True)
1022         code_file = f'/tmp/{uuid}.code.bin'
1023         result_file = f'/tmp/{uuid}.result.bin'
1024
1025         logger.debug('Writing pickled code to %s', code_file)
1026         with open(code_file, 'wb') as wb:
1027             wb.write(pickle)
1028
1029         bundle = BundleDetails(
1030             pickled_code=pickle,
1031             uuid=uuid,
1032             fname=fname,
1033             worker=None,
1034             username=None,
1035             machine=None,
1036             hostname=platform.node(),
1037             code_file=code_file,
1038             result_file=result_file,
1039             pid=0,
1040             start_ts=time.time(),
1041             end_ts=0.0,
1042             slower_than_local_p95=False,
1043             slower_than_global_p95=False,
1044             src_bundle=None,
1045             is_cancelled=threading.Event(),
1046             was_cancelled=False,
1047             backup_bundles=[],
1048             failure_count=0,
1049         )
1050         self.status.record_bundle_details(bundle)
1051         logger.debug('%s: Created an original bundle', bundle)
1052         return bundle
1053
1054     def create_backup_bundle(self, src_bundle: BundleDetails):
1055         assert src_bundle.backup_bundles is not None
1056         n = len(src_bundle.backup_bundles)
1057         uuid = src_bundle.uuid + f'_backup#{n}'
1058
1059         backup_bundle = BundleDetails(
1060             pickled_code=src_bundle.pickled_code,
1061             uuid=uuid,
1062             fname=src_bundle.fname,
1063             worker=None,
1064             username=None,
1065             machine=None,
1066             hostname=src_bundle.hostname,
1067             code_file=src_bundle.code_file,
1068             result_file=src_bundle.result_file,
1069             pid=0,
1070             start_ts=time.time(),
1071             end_ts=0.0,
1072             slower_than_local_p95=False,
1073             slower_than_global_p95=False,
1074             src_bundle=src_bundle,
1075             is_cancelled=threading.Event(),
1076             was_cancelled=False,
1077             backup_bundles=None,  # backup backups not allowed
1078             failure_count=0,
1079         )
1080         src_bundle.backup_bundles.append(backup_bundle)
1081         self.status.record_bundle_details_already_locked(backup_bundle)
1082         logger.debug('%s: Created a backup bundle', backup_bundle)
1083         return backup_bundle
1084
1085     def schedule_backup_for_bundle(self, src_bundle: BundleDetails):
1086         assert self.status.lock.locked()
1087         assert src_bundle is not None
1088         backup_bundle = self.create_backup_bundle(src_bundle)
1089         logger.debug(
1090             '%s/%s: Scheduling backup for execution...', backup_bundle.uuid, backup_bundle.fname
1091         )
1092         self._helper_executor.submit(self.launch, backup_bundle)
1093
1094         # Results from backups don't matter; if they finish first
1095         # they will move the result_file to this machine and let
1096         # the original pick them up and unpickle them.
1097
1098     def emergency_retry_nasty_bundle(self, bundle: BundleDetails) -> Optional[fut.Future]:
1099         is_original = bundle.src_bundle is None
1100         bundle.worker = None
1101         avoid_last_machine = bundle.machine
1102         bundle.machine = None
1103         bundle.username = None
1104         bundle.failure_count += 1
1105         if is_original:
1106             retry_limit = 3
1107         else:
1108             retry_limit = 2
1109
1110         if bundle.failure_count > retry_limit:
1111             logger.error(
1112                 '%s: Tried this bundle too many times already (%dx); giving up.',
1113                 bundle,
1114                 retry_limit,
1115             )
1116             if is_original:
1117                 raise RemoteExecutorException(
1118                     f'{bundle}: This bundle can\'t be completed despite several backups and retries',
1119                 )
1120             else:
1121                 logger.error(
1122                     '%s: At least it\'s only a backup; better luck with the others.', bundle
1123                 )
1124             return None
1125         else:
1126             msg = f'>>> Emergency rescheduling {bundle} because of unexpected errors (wtf?!) <<<'
1127             logger.warning(msg)
1128             warnings.warn(msg)
1129             return self.launch(bundle, avoid_last_machine)
1130
1131     @overrides
1132     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
1133         if self.already_shutdown:
1134             raise Exception('Submitted work after shutdown.')
1135         pickle = make_cloud_pickle(function, *args, **kwargs)
1136         bundle = self.create_original_bundle(pickle, function.__name__)
1137         self.total_bundles_submitted += 1
1138         return self._helper_executor.submit(self.launch, bundle)
1139
1140     @overrides
1141     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
1142         if not self.already_shutdown:
1143             logger.debug('Shutting down RemoteExecutor %s', self.title)
1144             self.heartbeat_stop_event.set()
1145             self.heartbeat_thread.join()
1146             self._helper_executor.shutdown(wait)
1147             if not quiet:
1148                 print(self.histogram.__repr__(label_formatter='%ds'))
1149             self.already_shutdown = True
1150
1151
1152 @singleton
1153 class DefaultExecutors(object):
1154     """A container for a default thread, process and remote executor.
1155     These are not created until needed and we take care to clean up
1156     before process exit.
1157
1158     """
1159
1160     def __init__(self):
1161         self.thread_executor: Optional[ThreadExecutor] = None
1162         self.process_executor: Optional[ProcessExecutor] = None
1163         self.remote_executor: Optional[RemoteExecutor] = None
1164
1165     @staticmethod
1166     def ping(host) -> bool:
1167         logger.debug('RUN> ping -c 1 %s', host)
1168         try:
1169             x = cmd_with_timeout(f'ping -c 1 {host} >/dev/null 2>/dev/null', timeout_seconds=1.0)
1170             return x == 0
1171         except Exception:
1172             return False
1173
1174     def thread_pool(self) -> ThreadExecutor:
1175         if self.thread_executor is None:
1176             self.thread_executor = ThreadExecutor()
1177         return self.thread_executor
1178
1179     def process_pool(self) -> ProcessExecutor:
1180         if self.process_executor is None:
1181             self.process_executor = ProcessExecutor()
1182         return self.process_executor
1183
1184     def remote_pool(self) -> RemoteExecutor:
1185         if self.remote_executor is None:
1186             logger.info('Looking for some helper machines...')
1187             pool: List[RemoteWorkerRecord] = []
1188             if self.ping('cheetah.house'):
1189                 logger.info('Found cheetah.house')
1190                 pool.append(
1191                     RemoteWorkerRecord(
1192                         username='scott',
1193                         machine='cheetah.house',
1194                         weight=30,
1195                         count=6,
1196                     ),
1197                 )
1198             if self.ping('meerkat.cabin'):
1199                 logger.info('Found meerkat.cabin')
1200                 pool.append(
1201                     RemoteWorkerRecord(
1202                         username='scott',
1203                         machine='meerkat.cabin',
1204                         weight=12,
1205                         count=2,
1206                     ),
1207                 )
1208             if self.ping('wannabe.house'):
1209                 logger.info('Found wannabe.house')
1210                 pool.append(
1211                     RemoteWorkerRecord(
1212                         username='scott',
1213                         machine='wannabe.house',
1214                         weight=25,
1215                         count=10,
1216                     ),
1217                 )
1218             if self.ping('puma.cabin'):
1219                 logger.info('Found puma.cabin')
1220                 pool.append(
1221                     RemoteWorkerRecord(
1222                         username='scott',
1223                         machine='puma.cabin',
1224                         weight=30,
1225                         count=6,
1226                     ),
1227                 )
1228             if self.ping('backup.house'):
1229                 logger.info('Found backup.house')
1230                 pool.append(
1231                     RemoteWorkerRecord(
1232                         username='scott',
1233                         machine='backup.house',
1234                         weight=8,
1235                         count=2,
1236                     ),
1237                 )
1238
1239             # The controller machine has a lot to do; go easy on it.
1240             for record in pool:
1241                 if record.machine == platform.node() and record.count > 1:
1242                     logger.info('Reducing workload for %s.', record.machine)
1243                     record.count = 1
1244
1245             policy = WeightedRandomRemoteWorkerSelectionPolicy()
1246             policy.register_worker_pool(pool)
1247             self.remote_executor = RemoteExecutor(pool, policy)
1248         return self.remote_executor
1249
1250     def shutdown(self) -> None:
1251         if self.thread_executor is not None:
1252             self.thread_executor.shutdown(wait=True, quiet=True)
1253             self.thread_executor = None
1254         if self.process_executor is not None:
1255             self.process_executor.shutdown(wait=True, quiet=True)
1256             self.process_executor = None
1257         if self.remote_executor is not None:
1258             self.remote_executor.shutdown(wait=True, quiet=True)
1259             self.remote_executor = None
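
# A hedged sketch of the intended lifecycle (the atexit hookup below is an
# assumption about how callers typically wire this container up; it is not
# defined in this file):
#
#     import atexit
#
#     defaults = DefaultExecutors()         # @singleton: same instance everywhere
#     atexit.register(defaults.shutdown)    # drain all pools before interpreter exit
#     f = defaults.thread_pool().submit(sum, [1, 2, 3])
#     assert f.result() == 6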