1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3
4 """Defines three executors: a thread executor for doing work using a
5 threadpool, a process executor for doing work in other processes on
6 the same machine, and a remote executor for farming out work to other
7 machines.
8
9 Also defines DefaultExecutors, which is a container for references to
10 global executors / worker pools with automatic shutdown semantics."""
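# A minimal usage sketch (an illustration only; it assumes the repo's config /
# argument-parsing bootstrap has already run and that the @singleton decorator
# below makes repeated DefaultExecutors() calls return the same instance):
#
#     from executors import DefaultExecutors
#
#     def add(x: int, y: int) -> int:
#         return x + y
#
#     pool = DefaultExecutors().thread_pool()
#     future = pool.submit(add, 1, 2)
#     assert future.result() == 3
#     DefaultExecutors().shutdown()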
11
12 from __future__ import annotations
13 import concurrent.futures as fut
14 import logging
15 import os
16 import platform
17 import random
18 import subprocess
19 import threading
20 import time
21 import warnings
22 from abc import ABC, abstractmethod
23 from collections import defaultdict
24 from dataclasses import dataclass
25 from typing import Any, Callable, Dict, List, Optional, Set
26
27 import cloudpickle  # type: ignore
28 import numpy
29 from overrides import overrides
30
31 import argparse_utils
32 import config
33 import histogram as hist
34 from ansi import bg, fg, reset, underline
35 from decorator_utils import singleton
36 from exec_utils import cmd_in_background, cmd_with_timeout, run_silently
37 from thread_utils import background_thread
38
39 logger = logging.getLogger(__name__)
40
41 parser = config.add_commandline_args(
42     f"Executors ({__file__})", "Args related to processing executors."
43 )
44 parser.add_argument(
45     '--executors_threadpool_size',
46     type=int,
47     metavar='#THREADS',
48     help='Number of threads in the default threadpool, leave unset for default',
49     default=None,
50 )
51 parser.add_argument(
52     '--executors_processpool_size',
53     type=int,
54     metavar='#PROCESSES',
55     help='Number of processes in the default processpool, leave unset for default',
56     default=None,
57 )
58 parser.add_argument(
59     '--executors_schedule_remote_backups',
60     default=True,
61     action=argparse_utils.ActionNoYes,
62     help='Should we schedule duplicative backup work if a remote bundle is slow',
63 )
64 parser.add_argument(
65     '--executors_max_bundle_failures',
66     type=int,
67     default=3,
68     metavar='#FAILURES',
69     help='Maximum number of failures before giving up on a bundle',
70 )
71
72 SSH = '/usr/bin/ssh -oForwardX11=no'
73 SCP = '/usr/bin/scp -C'
74
75
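# Note: cloudpickle is used instead of the stdlib pickle because it can
# serialize lambdas, closures and interactively defined functions, which lets
# arbitrary callables be shipped to worker processes and remote machines.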
76 def make_cloud_pickle(fun, *args, **kwargs):
77     logger.debug("Making cloudpickled bundle at %s", fun.__name__)
78     return cloudpickle.dumps((fun, args, kwargs))
79
80
81 class BaseExecutor(ABC):
82     """The base executor interface definition."""
83
84     def __init__(self, *, title=''):
85         self.title = title
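        # Tracks per-task wall clock latency (in seconds) using 50 evenly
        # spaced buckets between 0 and 500.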
86         self.histogram = hist.SimpleHistogram(
87             hist.SimpleHistogram.n_evenly_spaced_buckets(int(0), int(500), 50)
88         )
89         self.task_count = 0
90
91     @abstractmethod
92     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
93         pass
94
95     @abstractmethod
96     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
97         pass
98
99     def shutdown_if_idle(self, *, quiet: bool = False) -> bool:
100         """If the executor is idle (i.e. there are no pending or active
101         tasks), shut it down and return True; otherwise do nothing and return
102         False.  Note: this should only be called by the launcher
103         process.
104
105         """
106         if self.task_count == 0:
107             self.shutdown(wait=True, quiet=quiet)
108             return True
109         return False
110
111     def adjust_task_count(self, delta: int) -> None:
112         """Change the task count.  Note: do not call this method from a
113         worker, it should only be called by the launcher process /
114         thread / machine.
115
116         """
117         self.task_count += delta
118         logger.debug('Adjusted task count by %d to %d.', delta, self.task_count)
119
120     def get_task_count(self) -> int:
121         """Return the current task count.  Note: do not call this method
122         from a worker, it should only be called by the launcher process /
123         thread / machine.
124
125         """
126         return self.task_count
127
128
129 class ThreadExecutor(BaseExecutor):
130     """A threadpool executor instance."""
131
132     def __init__(self, max_workers: Optional[int] = None):
133         super().__init__()
134         workers = None
135         if max_workers is not None:
136             workers = max_workers
137         elif 'executors_threadpool_size' in config.config:
138             workers = config.config['executors_threadpool_size']
139         logger.debug('Creating threadpool executor with %s workers', workers)
140         self._thread_pool_executor = fut.ThreadPoolExecutor(
141             max_workers=workers, thread_name_prefix="thread_executor_helper"
142         )
143         self.already_shutdown = False
144
145     # This is run on a different thread; do not adjust task count here.
146     @staticmethod
147     def run_local_bundle(fun, *args, **kwargs):
148         logger.debug("Running local bundle at %s", fun.__name__)
149         result = fun(*args, **kwargs)
150         return result
151
152     @overrides
153     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
154         if self.already_shutdown:
155             raise Exception('Submitted work after shutdown.')
156         self.adjust_task_count(+1)
157         newargs = []
158         newargs.append(function)
159         for arg in args:
160             newargs.append(arg)
161         start = time.time()
162         result = self._thread_pool_executor.submit(
163             ThreadExecutor.run_local_bundle, *newargs, **kwargs
164         )
165         result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
166         result.add_done_callback(lambda _: self.adjust_task_count(-1))
167         return result
168
169     @overrides
170     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
171         if not self.already_shutdown:
172             logger.debug('Shutting down threadpool executor %s', self.title)
173             self._thread_pool_executor.shutdown(wait)
174             if not quiet:
175                 print(self.histogram.__repr__(label_formatter='%ds'))
176             self.already_shutdown = True
177
178
179 class ProcessExecutor(BaseExecutor):
180     """A processpool executor."""
181
182     def __init__(self, max_workers=None):
183         super().__init__()
184         workers = None
185         if max_workers is not None:
186             workers = max_workers
187         elif 'executors_processpool_size' in config.config:
188             workers = config.config['executors_processpool_size']
189         logger.debug('Creating processpool executor with %s workers.', workers)
190         self._process_executor = fut.ProcessPoolExecutor(
191             max_workers=workers,
192         )
193         self.already_shutdown = False
194
195     # This is run in another process; do not adjust task count here.
196     @staticmethod
197     def run_cloud_pickle(pickle):
198         fun, args, kwargs = cloudpickle.loads(pickle)
199         logger.debug("Running pickled bundle at %s", fun.__name__)
200         result = fun(*args, **kwargs)
201         return result
202
203     @overrides
204     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
205         if self.already_shutdown:
206             raise Exception('Submitted work after shutdown.')
207         start = time.time()
208         self.adjust_task_count(+1)
209         pickle = make_cloud_pickle(function, *args, **kwargs)
210         result = self._process_executor.submit(ProcessExecutor.run_cloud_pickle, pickle)
211         result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
212         result.add_done_callback(lambda _: self.adjust_task_count(-1))
213         return result
214
215     @overrides
216     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
217         if not self.already_shutdown:
218             logger.debug('Shutting down processpool executor %s', self.title)
219             self._process_executor.shutdown(wait)
220             if not quiet:
221                 print(self.histogram.__repr__(label_formatter='%ds'))
222             self.already_shutdown = True
223
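    # Drop the (unpicklable) underlying process pool if an instance of this
    # executor is itself pickled, e.g. because it was captured by work that
    # gets shipped to another process.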
224     def __getstate__(self):
225         state = self.__dict__.copy()
226         state['_process_executor'] = None
227         return state
228
229
230 class RemoteExecutorException(Exception):
231     """Thrown when a bundle cannot be executed despite several retries."""
232
233     pass
234
235
236 @dataclass
237 class RemoteWorkerRecord:
238     """A record of info about a remote worker."""
239
240     username: str
241     machine: str
242     weight: int
243     count: int
244
245     def __hash__(self):
246         return hash((self.username, self.machine))
247
248     def __repr__(self):
249         return f'{self.username}@{self.machine}'
250
251
252 @dataclass
253 class BundleDetails:
254     """All info necessary to define some unit of work that needs to be
255     done, where it is being run, its state, whether it is an original
256     bundle or a backup bundle, how many times it has failed, etc...
257
258     """
259
260     pickled_code: bytes
261     uuid: str
262     fname: str
263     worker: Optional[RemoteWorkerRecord]
264     username: Optional[str]
265     machine: Optional[str]
266     hostname: str
267     code_file: str
268     result_file: str
269     pid: int
270     start_ts: float
271     end_ts: float
272     slower_than_local_p95: bool
273     slower_than_global_p95: bool
274     src_bundle: Optional[BundleDetails]
275     is_cancelled: threading.Event
276     was_cancelled: bool
277     backup_bundles: Optional[List[BundleDetails]]
278     failure_count: int
279
280     def __repr__(self):
281         uuid = self.uuid
282         if uuid[-9:-2] == '_backup':
283             uuid = uuid[:-9]
284             suffix = f'{uuid[-6:]}_b{self.uuid[-1:]}'
285         else:
286             suffix = uuid[-6:]
287
288         colorz = [
289             fg('violet red'),
290             fg('red'),
291             fg('orange'),
292             fg('peach orange'),
293             fg('yellow'),
294             fg('marigold yellow'),
295             fg('green yellow'),
296             fg('tea green'),
297             fg('cornflower blue'),
298             fg('turquoise blue'),
299             fg('tropical blue'),
300             fg('lavender purple'),
301             fg('medium purple'),
302         ]
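        # Pick a deterministic per-bundle color from the last two hex digits
        # of the uuid so a given bundle always renders the same way.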
303         c = colorz[int(uuid[-2:], 16) % len(colorz)]
304         fname = self.fname if self.fname is not None else 'nofname'
305         machine = self.machine if self.machine is not None else 'nomachine'
306         return f'{c}{suffix}/{fname}/{machine}{reset()}'
307
308
309 class RemoteExecutorStatus:
310     """A status 'scoreboard' for a remote executor."""
311
312     def __init__(self, total_worker_count: int) -> None:
313         self.worker_count: int = total_worker_count
314         self.known_workers: Set[RemoteWorkerRecord] = set()
315         self.start_time: float = time.time()
316         self.start_per_bundle: Dict[str, Optional[float]] = defaultdict(float)
317         self.end_per_bundle: Dict[str, float] = defaultdict(float)
318         self.finished_bundle_timings_per_worker: Dict[RemoteWorkerRecord, List[float]] = {}
319         self.in_flight_bundles_by_worker: Dict[RemoteWorkerRecord, Set[str]] = {}
320         self.bundle_details_by_uuid: Dict[str, BundleDetails] = {}
321         self.finished_bundle_timings: List[float] = []
322         self.last_periodic_dump: Optional[float] = None
323         self.total_bundles_submitted: int = 0
324
325         # Protects reads and modification using self.  Also used
326         # as a memory fence for modifications to bundle.
327         self.lock: threading.Lock = threading.Lock()
328
329     def record_acquire_worker(self, worker: RemoteWorkerRecord, uuid: str) -> None:
330         with self.lock:
331             self.record_acquire_worker_already_locked(worker, uuid)
332
333     def record_acquire_worker_already_locked(self, worker: RemoteWorkerRecord, uuid: str) -> None:
334         assert self.lock.locked()
335         self.known_workers.add(worker)
336         self.start_per_bundle[uuid] = None
337         x = self.in_flight_bundles_by_worker.get(worker, set())
338         x.add(uuid)
339         self.in_flight_bundles_by_worker[worker] = x
340
341     def record_bundle_details(self, details: BundleDetails) -> None:
342         with self.lock:
343             self.record_bundle_details_already_locked(details)
344
345     def record_bundle_details_already_locked(self, details: BundleDetails) -> None:
346         assert self.lock.locked()
347         self.bundle_details_by_uuid[details.uuid] = details
348
349     def record_release_worker(
350         self,
351         worker: RemoteWorkerRecord,
352         uuid: str,
353         was_cancelled: bool,
354     ) -> None:
355         with self.lock:
356             self.record_release_worker_already_locked(worker, uuid, was_cancelled)
357
358     def record_release_worker_already_locked(
359         self,
360         worker: RemoteWorkerRecord,
361         uuid: str,
362         was_cancelled: bool,
363     ) -> None:
364         assert self.lock.locked()
365         ts = time.time()
366         self.end_per_bundle[uuid] = ts
367         self.in_flight_bundles_by_worker[worker].remove(uuid)
368         if not was_cancelled:
369             start = self.start_per_bundle[uuid]
370             assert start is not None
371             bundle_latency = ts - start
372             x = self.finished_bundle_timings_per_worker.get(worker, [])
373             x.append(bundle_latency)
374             self.finished_bundle_timings_per_worker[worker] = x
375             self.finished_bundle_timings.append(bundle_latency)
376
377     def record_processing_began(self, uuid: str):
378         with self.lock:
379             self.start_per_bundle[uuid] = time.time()
380
381     def total_in_flight(self) -> int:
382         assert self.lock.locked()
383         total_in_flight = 0
384         for worker in self.known_workers:
385             total_in_flight += len(self.in_flight_bundles_by_worker[worker])
386         return total_in_flight
387
388     def total_idle(self) -> int:
389         assert self.lock.locked()
390         return self.worker_count - self.total_in_flight()
391
392     def __repr__(self):
393         assert self.lock.locked()
394         ts = time.time()
395         total_finished = len(self.finished_bundle_timings)
396         total_in_flight = self.total_in_flight()
397         ret = f'\n\n{underline()}Remote Executor Pool Status{reset()}: '
398         qall = None
399         if len(self.finished_bundle_timings) > 1:
400             qall = numpy.quantile(self.finished_bundle_timings, [0.5, 0.95])
401             ret += (
402                 f'⏱=∀p50:{qall[0]:.1f}s, ∀p95:{qall[1]:.1f}s, total={ts-self.start_time:.1f}s, '
403                 f'✅={total_finished}/{self.total_bundles_submitted}, '
404                 f'💻n={total_in_flight}/{self.worker_count}\n'
405             )
406         else:
407             ret += (
408                 f'⏱={ts-self.start_time:.1f}s, '
409                 f'✅={total_finished}/{self.total_bundles_submitted}, '
410                 f'💻n={total_in_flight}/{self.worker_count}\n'
411             )
412
413         for worker in self.known_workers:
414             ret += f'  {fg("lightning yellow")}{worker.machine}{reset()}: '
415             timings = self.finished_bundle_timings_per_worker.get(worker, [])
416             count = len(timings)
417             qworker = None
418             if count > 1:
419                 qworker = numpy.quantile(timings, [0.5, 0.95])
420                 ret += f' 💻p50: {qworker[0]:.1f}s, 💻p95: {qworker[1]:.1f}s\n'
421             else:
422                 ret += '\n'
423             if count > 0:
424                 ret += f'    ...finished {count} total bundle(s) so far\n'
425             in_flight = len(self.in_flight_bundles_by_worker[worker])
426             if in_flight > 0:
427                 ret += f'    ...{in_flight} bundles currently in flight:\n'
428                 for bundle_uuid in self.in_flight_bundles_by_worker[worker]:
429                     details = self.bundle_details_by_uuid.get(bundle_uuid, None)
430                     pid = str(details.pid) if (details and details.pid != 0) else "TBD"
431                     if self.start_per_bundle[bundle_uuid] is not None:
432                         sec = ts - self.start_per_bundle[bundle_uuid]
433                         ret += f'       (pid={pid}): {details} for {sec:.1f}s so far '
434                     else:
435                         ret += f'       {details} setting up / copying data...'
436                         sec = 0.0
437
438                     if qworker is not None:
439                         if sec > qworker[1]:
440                             ret += f'{bg("red")}>💻p95{reset()} '
441                             if details is not None:
442                                 details.slower_than_local_p95 = True
443                         else:
444                             if details is not None:
445                                 details.slower_than_local_p95 = False
446
447                     if qall is not None:
448                         if sec > qall[1]:
449                             ret += f'{bg("red")}>∀p95{reset()} '
450                             if details is not None:
451                                 details.slower_than_global_p95 = True
452                         elif details is not None:
453                             details.slower_than_global_p95 = False
454                     ret += '\n'
455         return ret
456
457     def periodic_dump(self, total_bundles_submitted: int) -> None:
458         assert self.lock.locked()
459         self.total_bundles_submitted = total_bundles_submitted
460         ts = time.time()
461         if self.last_periodic_dump is None or ts - self.last_periodic_dump > 5.0:
462             print(self)
463             self.last_periodic_dump = ts
464
465
466 class RemoteWorkerSelectionPolicy(ABC):
467     """Base class for policies that select remote workers."""
468
469     def __init__(self):
470         self.workers: Optional[List[RemoteWorkerRecord]] = None
471
472     def register_worker_pool(self, workers: List[RemoteWorkerRecord]):
473         self.workers = workers
474
475     @abstractmethod
476     def is_worker_available(self) -> bool:
477         pass
478
479     @abstractmethod
480     def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
481         pass
482
483
484 class WeightedRandomRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
485     """A remote worker selector that uses weighted RNG."""
486
487     @overrides
488     def is_worker_available(self) -> bool:
489         if self.workers:
490             for worker in self.workers:
491                 if worker.count > 0:
492                     return True
493         return False
494
495     @overrides
496     def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
497         grabbag = []
498         if self.workers:
499             for worker in self.workers:
500                 if worker.machine != machine_to_avoid:
501                     if worker.count > 0:
502                         for _ in range(worker.count * worker.weight):
503                             grabbag.append(worker)
504
505         if len(grabbag) == 0:
506             logger.debug('There are no available workers that avoid %s', machine_to_avoid)
507             if self.workers:
508                 for worker in self.workers:
509                     if worker.count > 0:
510                         for _ in range(worker.count * worker.weight):
511                             grabbag.append(worker)
512
513         if len(grabbag) == 0:
514             logger.warning('There are no available workers?!')
515             return None
516
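        # Weighted selection: e.g. a worker with count=6 and weight=24 lands
        # in the grabbag 6 * 24 = 144 times, so capable and idle machines are
        # proportionally more likely to be chosen.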
517         worker = random.sample(grabbag, 1)[0]
518         assert worker.count > 0
519         worker.count -= 1
520         logger.debug('Selected worker %s', worker)
521         return worker
522
523
524 class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
525     """A remote worker selector that just round robins."""
526
527     def __init__(self) -> None:
528         super().__init__()
529         self.index = 0
530
531     @overrides
532     def is_worker_available(self) -> bool:
533         if self.workers:
534             for worker in self.workers:
535                 if worker.count > 0:
536                     return True
537         return False
538
539     @overrides
540     def acquire_worker(self, machine_to_avoid: Optional[str] = None) -> Optional[RemoteWorkerRecord]:
541         if self.workers:
542             x = self.index
543             while True:
544                 worker = self.workers[x]
545                 if worker.count > 0:
546                     worker.count -= 1
547                     x += 1
548                     if x >= len(self.workers):
549                         x = 0
550                     self.index = x
551                     logger.debug('Selected worker %s', worker)
552                     return worker
553                 x += 1
554                 if x >= len(self.workers):
555                     x = 0
556                 if x == self.index:
557                     logger.warning('Unexpectedly could not find an available worker; giving up.')
558                     return None
559         return None
560
561
562 class RemoteExecutor(BaseExecutor):
563     """A remote work executor."""
564
565     def __init__(
566         self,
567         workers: List[RemoteWorkerRecord],
568         policy: RemoteWorkerSelectionPolicy,
569     ) -> None:
570         super().__init__()
571         self.workers = workers
572         self.policy = policy
573         self.worker_count = 0
574         for worker in self.workers:
575             self.worker_count += worker.count
576         if self.worker_count <= 0:
577             msg = f"We need somewhere to schedule work; count was {self.worker_count}"
578             logger.critical(msg)
579             raise RemoteExecutorException(msg)
580         self.policy.register_worker_pool(self.workers)
581         self.cv = threading.Condition()
582         logger.debug('Creating %d local threads, one per remote worker.', self.worker_count)
583         self._helper_executor = fut.ThreadPoolExecutor(
584             thread_name_prefix="remote_executor_helper",
585             max_workers=self.worker_count,
586         )
587         self.status = RemoteExecutorStatus(self.worker_count)
588         self.total_bundles_submitted = 0
589         self.backup_lock = threading.Lock()
590         self.last_backup = None
591         (
592             self.heartbeat_thread,
593             self.heartbeat_stop_event,
594         ) = self.run_periodic_heartbeat()
595         self.already_shutdown = False
596
597     @background_thread
598     def run_periodic_heartbeat(self, stop_event: threading.Event) -> None:
599         while not stop_event.is_set():
600             time.sleep(5.0)
601             logger.debug('Running periodic heartbeat code...')
602             self.heartbeat()
603         logger.debug('Periodic heartbeat thread shutting down.')
604
605     def heartbeat(self) -> None:
606         # Note: this is invoked on a background thread, not an
607         # executor thread.  Be careful what you do with it b/c it
608         # needs to get back and dump status again periodically.
609         with self.status.lock:
610             self.status.periodic_dump(self.total_bundles_submitted)
611
612             # Look for bundles to reschedule via executor.submit
613             if config.config['executors_schedule_remote_backups']:
614                 self.maybe_schedule_backup_bundles()
615
616     def maybe_schedule_backup_bundles(self):
617         assert self.status.lock.locked()
618         num_done = len(self.status.finished_bundle_timings)
619         num_idle_workers = self.worker_count - self.task_count
620         now = time.time()
621         if (
622             num_done > 2
623             and num_idle_workers > 1
624             and (self.last_backup is None or (now - self.last_backup > 9.0))
625             and self.backup_lock.acquire(blocking=False)
626         ):
627             try:
628                 assert self.backup_lock.locked()
629
630                 bundle_to_backup = None
631                 best_score = None
632                 for (
633                     worker,
634                     bundle_uuids,
635                 ) in self.status.in_flight_bundles_by_worker.items():
636
637                     # Prefer to schedule backups of bundles running on
638                     # slower machines.
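                    # (e.g. weight=12 yields base_score = int(200 / 12) = 16
                    # while weight=24 yields 8, so lower-weight machines start
                    # with a higher score).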
639                     base_score = 0
640                     for record in self.workers:
641                         if worker.machine == record.machine:
642                             base_score = float(record.weight)
643                             base_score = 1.0 / base_score
644                             base_score *= 200.0
645                             base_score = int(base_score)
646                             break
647
648                     for uuid in bundle_uuids:
649                         bundle = self.status.bundle_details_by_uuid.get(uuid, None)
650                         if (
651                             bundle is not None
652                             and bundle.src_bundle is None
653                             and bundle.backup_bundles is not None
654                         ):
655                             score = base_score
656
657                             # Schedule backups of bundles running
658                             # longer; especially those that are
659                             # unexpectedly slow.
660                             start_ts = self.status.start_per_bundle[uuid]
661                             if start_ts is not None:
662                                 runtime = now - start_ts
663                                 score += runtime
664                                 logger.debug('score[%s] => %.1f  # latency boost', bundle, score)
665
666                                 if bundle.slower_than_local_p95:
667                                     score += runtime / 2
668                                     logger.debug('score[%s] => %.1f  # >worker p95', bundle, score)
669
670                                 if bundle.slower_than_global_p95:
671                                     score += runtime / 4
672                                     logger.debug('score[%s] => %.1f  # >global p95', bundle, score)
673
674                             # Prefer backups of bundles that don't
675                             # have backups already.
676                             backup_count = len(bundle.backup_bundles)
677                             if backup_count == 0:
678                                 score *= 2
679                             elif backup_count == 1:
680                                 score /= 2
681                             elif backup_count == 2:
682                                 score /= 8
683                             else:
684                                 score = 0
685                             logger.debug(
686                                 'score[%s] => %.1f  # %d dup backup factor',
687                                 bundle,
688                                 score, backup_count,
689                             )
690
691                             if score != 0 and (best_score is None or score > best_score):
692                                 bundle_to_backup = bundle
693                                 assert bundle is not None
694                                 assert bundle.backup_bundles is not None
695                                 assert bundle.src_bundle is None
696                                 best_score = score
697
698                 # Note: this is all still happening on the heartbeat
699                 # runner thread.  That's ok because
700                 # schedule_backup_for_bundle uses the executor to
701                 # submit the bundle again which will cause it to be
702                 # picked up by a worker thread and allow this thread
703                 # to return to run future heartbeats.
704                 if bundle_to_backup is not None:
705                     self.last_backup = now
706                     logger.info(
707                         '=====> SCHEDULING BACKUP %s (score=%.1f) <=====',
708                         bundle_to_backup,
709                         best_score,
710                     )
711                     self.schedule_backup_for_bundle(bundle_to_backup)
712             finally:
713                 self.backup_lock.release()
714
715     def is_worker_available(self) -> bool:
716         return self.policy.is_worker_available()
717
718     def acquire_worker(self, machine_to_avoid: Optional[str] = None) -> Optional[RemoteWorkerRecord]:
719         return self.policy.acquire_worker(machine_to_avoid)
720
721     def find_available_worker_or_block(self, machine_to_avoid: Optional[str] = None) -> RemoteWorkerRecord:
722         with self.cv:
723             while not self.is_worker_available():
724                 self.cv.wait()
725             worker = self.acquire_worker(machine_to_avoid)
726             if worker is not None:
727                 return worker
728         msg = "We should never reach this point in the code"
729         logger.critical(msg)
730         raise Exception(msg)
731
732     def release_worker(self, bundle: BundleDetails, *, was_cancelled=True) -> None:
733         worker = bundle.worker
734         assert worker is not None
735         logger.debug('Released worker %s', worker)
736         self.status.record_release_worker(
737             worker,
738             bundle.uuid,
739             was_cancelled,
740         )
741         with self.cv:
742             worker.count += 1
743             self.cv.notify()
744         self.adjust_task_count(-1)
745
746     def check_if_cancelled(self, bundle: BundleDetails) -> bool:
747         with self.status.lock:
748             if bundle.is_cancelled.wait(timeout=0.0):
749                 logger.debug('Bundle %s is cancelled, bail out.', bundle.uuid)
750                 bundle.was_cancelled = True
751                 return True
752         return False
753
754     def launch(self, bundle: BundleDetails, override_avoid_machine=None) -> Any:
755         """Find a worker for bundle or block until one is available."""
756         self.adjust_task_count(+1)
757         uuid = bundle.uuid
758         hostname = bundle.hostname
759         avoid_machine = override_avoid_machine
760         is_original = bundle.src_bundle is None
761
762         # Try not to schedule a backup on the same host as the original.
763         if avoid_machine is None and bundle.src_bundle is not None:
764             avoid_machine = bundle.src_bundle.machine
765         worker = None
766         while worker is None:
767             worker = self.find_available_worker_or_block(avoid_machine)
768         assert worker is not None
769
770         # Ok, found a worker.
771         bundle.worker = worker
772         machine = bundle.machine = worker.machine
773         username = bundle.username = worker.username
774         self.status.record_acquire_worker(worker, uuid)
775         logger.debug('%s: Running bundle on %s...', bundle, worker)
776
777         # Before we do any work, make sure the bundle is still viable.
778         # It may have been some time between when it was submitted and
779         # now due to lack of worker availability and someone else may
780         # have already finished it.
781         if self.check_if_cancelled(bundle):
782             try:
783                 return self.process_work_result(bundle)
784             except Exception as e:
785                 logger.warning('%s: bundle says it\'s cancelled upfront but no results?!', bundle)
786                 self.release_worker(bundle)
787                 if is_original:
788                     # Weird.  We are the original owner of this
789                     # bundle.  For it to have been cancelled, a backup
790                     # must have already started and completed before
791                     # we even got started.  Moreover, the backup says
792                     # it is done but we can't find the results it
793                     # should have copied over.  Reschedule the whole
794                     # thing.
795                     logger.exception(e)
796                     logger.error(
797                         '%s: We are the original owner thread and yet there are '
798                         'no results for this bundle.  This is unexpected and bad.',
799                         bundle,
800                     )
801                     return self.emergency_retry_nasty_bundle(bundle)
802                 else:
803                     # Expected(?).  We're a backup and our bundle is
804                     # cancelled before we even got started.  Something
805                     # went bad in process_work_result (I actually don't
806                     # see what?) but probably not worth worrying
807                     # about.  Let the original thread worry about
808                     # either finding the results or complaining about
809                     # it.
810                     return None
811
812         # Send input code / data to worker machine if it's not local.
813         if hostname not in machine:
814             try:
815                 cmd = f'{SCP} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
816                 start_ts = time.time()
817                 logger.info("%s: Copying work to %s via %s.", bundle, worker, cmd)
818                 run_silently(cmd)
819                 xfer_latency = time.time() - start_ts
820                 logger.debug("%s: Copying to %s took %.1fs.", bundle, worker, xfer_latency)
821             except Exception as e:
822                 self.release_worker(bundle)
823                 if is_original:
824                     # Weird.  We tried to copy the code to the worker and it failed...
825                     # And we're the original bundle.  We have to retry.
826                     logger.exception(e)
827                     logger.error(
828                         "%s: Failed to send instructions to the worker machine?! "
829                         "This is not expected; we\'re the original bundle so this shouldn\'t "
830                         "be a race condition.  Attempting an emergency retry...",
831                         bundle,
832                     )
833                     return self.emergency_retry_nasty_bundle(bundle)
834                 else:
835                     # This is actually expected; we're a backup.
836                     # There's a race condition where someone else
837                     # already finished the work and removed the source
838                     # code file before we could copy it.  No biggie.
839                     logger.warning(
840                         '%s: Failed to send instructions to the worker machine... '
841                         'We\'re a backup and this may be caused by the original (or '
842                         'some other backup) already finishing this work.  Ignoring.',
843                         bundle,
844                     )
845                     return None
846
847         # Kick off the work.  Note that if this fails we let
848         # wait_for_process deal with it.
849         self.status.record_processing_began(uuid)
850         cmd = (
851             f'{SSH} {bundle.username}@{bundle.machine} '
852             f'"source py38-venv/bin/activate &&'
853             f' /home/scott/lib/python_modules/remote_worker.py'
854             f' --code_file {bundle.code_file} --result_file {bundle.result_file}"'
855         )
856         logger.debug('%s: Executing %s in the background to kick off work...', bundle, cmd)
857         p = cmd_in_background(cmd, silent=True)
858         bundle.pid = p.pid
859         logger.debug('%s: Local ssh process pid=%d; remote worker is %s.', bundle, p.pid, machine)
860         return self.wait_for_process(p, bundle, 0)
861
862     def wait_for_process(
863         self, p: Optional[subprocess.Popen], bundle: BundleDetails, depth: int
864     ) -> Any:
865         machine = bundle.machine
866         assert p is not None
867         pid = p.pid
868         if depth > 3:
869             logger.error(
870                 "I've gotten repeated errors waiting on this bundle; giving up on pid=%d", pid
871             )
872             p.terminate()
873             self.release_worker(bundle)
874             return self.emergency_retry_nasty_bundle(bundle)
875
876         # Spin until either the ssh job we scheduled finishes the
877         # bundle or some backup worker signals that they finished it
878         # before we could.
879         while True:
880             try:
881                 p.wait(timeout=0.25)
882             except subprocess.TimeoutExpired:
883                 if self.check_if_cancelled(bundle):
884                     logger.info('%s: looks like another worker finished bundle...', bundle)
885                     break
886             else:
887                 logger.info("%s: pid %d (%s) is finished!", bundle, pid, machine)
888                 p = None
889                 break
890
891         # If we get here we believe the bundle is done; either the ssh
892         # subprocess finished (hopefully successfully) or we noticed
893         # that some other worker seems to have completed the bundle
894         # and we're bailing out.
895         try:
896             ret = self.process_work_result(bundle)
897             if ret is not None and p is not None:
898                 p.terminate()
899             return ret
900
901         # Something went wrong; e.g. we could not copy the results
902         # back, cleanup after ourselves on the remote machine, or
903         # unpickle the results we got from the remote machine.  If we
904         # still have an active ssh subprocess, keep waiting on it.
905         # Otherwise, time for an emergency reschedule.
906         except Exception as e:
907             logger.exception(e)
908             logger.error('%s: Something unexpected just happened...', bundle)
909             if p is not None:
910                 logger.warning(
911                     "%s: Failed to wrap up \"done\" bundle, re-waiting on active ssh.", bundle
912                 )
913                 return self.wait_for_process(p, bundle, depth + 1)
914             else:
915                 self.release_worker(bundle)
916                 return self.emergency_retry_nasty_bundle(bundle)
917
918     def process_work_result(self, bundle: BundleDetails) -> Any:
919         with self.status.lock:
920             is_original = bundle.src_bundle is None
921             was_cancelled = bundle.was_cancelled
922             username = bundle.username
923             machine = bundle.machine
924             result_file = bundle.result_file
925             code_file = bundle.code_file
926
927             # Whether original or backup, if we finished first we must
928             # fetch the results if the computation happened on a
929             # remote machine.
930             bundle.end_ts = time.time()
931             if not was_cancelled:
932                 assert bundle.machine is not None
933                 if bundle.hostname not in bundle.machine:
934                     cmd = f'{SCP} {username}@{machine}:{result_file} {result_file} 2>/dev/null'
935                     logger.info(
936                         "%s: Fetching results back from %s@%s via %s",
937                         bundle,
938                         username,
939                         machine,
940                         cmd,
941                     )
942
943                     # If either of these throw they are handled in
944                     # wait_for_process.
945                     attempts = 0
946                     while True:
947                         try:
948                             run_silently(cmd)
949                         except Exception as e:
950                             attempts += 1
951                             if attempts >= 3:
952                                 raise e
953                         else:
954                             break
955
956                     run_silently(
957                         f'{SSH} {username}@{machine}' f' "/bin/rm -f {code_file} {result_file}"'
958                     )
959                     logger.debug('Fetching results back took %.2fs', time.time() - bundle.end_ts)
960                 dur = bundle.end_ts - bundle.start_ts
961                 self.histogram.add_item(dur)
962
963         # Only the original worker should unpickle the file contents
964         # though since it's the only one whose result matters.  The
965         # original is also the only job that may delete result_file
966         # from disk.  Note that the original may have been cancelled
967         # if one of the backups finished first; it still must read the
968         # result from disk.
969         if is_original:
970             logger.debug("%s: Unpickling %s.", bundle, result_file)
971             try:
972                 with open(result_file, 'rb') as rb:
973                     serialized = rb.read()
974                 result = cloudpickle.loads(serialized)
975             except Exception as e:
976                 logger.exception(e)
977                 logger.error('Failed to load %s... this is bad news.', result_file)
978                 self.release_worker(bundle)
979
980                 # Re-raise the exception; the code in wait_for_process may
981                 # decide to emergency_retry_nasty_bundle here.
982                 raise e
983             logger.debug('Removing local (master) %s and %s.', code_file, result_file)
984             os.remove(result_file)
985             os.remove(code_file)
986
987             # Notify any backups that the original is done so they
988             # should stop ASAP.  Do this whether or not we
989             # finished first since there could be more than one
990             # backup.
991             if bundle.backup_bundles is not None:
992                 for backup in bundle.backup_bundles:
993                     logger.debug(
994                         '%s: Notifying backup %s that it\'s cancelled', bundle, backup.uuid
995                     )
996                     backup.is_cancelled.set()
997
998         # This is a backup job and, by now, we have already fetched
999         # the bundle results.
1000         else:
1001             # Backup results don't matter, they just need to leave the
1002             # result file in the right place for their originals to
1003             # read/unpickle later.
1004             result = None
1005
1006             # Tell the original to stop if we finished first.
1007             if not was_cancelled:
1008                 orig_bundle = bundle.src_bundle
1009                 assert orig_bundle is not None
1010                 logger.debug(
1011                     '%s: Notifying original %s we beat them to it.', bundle, orig_bundle.uuid
1012                 )
1013                 orig_bundle.is_cancelled.set()
1014         self.release_worker(bundle, was_cancelled=was_cancelled)
1015         return result
1016
1017     def create_original_bundle(self, pickle, fname: str):
1018         from string_utils import generate_uuid
1019
1020         uuid = generate_uuid(omit_dashes=True)
1021         code_file = f'/tmp/{uuid}.code.bin'
1022         result_file = f'/tmp/{uuid}.result.bin'
1023
1024         logger.debug('Writing pickled code to %s', code_file)
1025         with open(code_file, 'wb') as wb:
1026             wb.write(pickle)
1027
1028         bundle = BundleDetails(
1029             pickled_code=pickle,
1030             uuid=uuid,
1031             fname=fname,
1032             worker=None,
1033             username=None,
1034             machine=None,
1035             hostname=platform.node(),
1036             code_file=code_file,
1037             result_file=result_file,
1038             pid=0,
1039             start_ts=time.time(),
1040             end_ts=0.0,
1041             slower_than_local_p95=False,
1042             slower_than_global_p95=False,
1043             src_bundle=None,
1044             is_cancelled=threading.Event(),
1045             was_cancelled=False,
1046             backup_bundles=[],
1047             failure_count=0,
1048         )
1049         self.status.record_bundle_details(bundle)
1050         logger.debug('%s: Created an original bundle', bundle)
1051         return bundle
1052
1053     def create_backup_bundle(self, src_bundle: BundleDetails):
1054         assert src_bundle.backup_bundles is not None
1055         n = len(src_bundle.backup_bundles)
1056         uuid = src_bundle.uuid + f'_backup#{n}'
1057
1058         backup_bundle = BundleDetails(
1059             pickled_code=src_bundle.pickled_code,
1060             uuid=uuid,
1061             fname=src_bundle.fname,
1062             worker=None,
1063             username=None,
1064             machine=None,
1065             hostname=src_bundle.hostname,
1066             code_file=src_bundle.code_file,
1067             result_file=src_bundle.result_file,
1068             pid=0,
1069             start_ts=time.time(),
1070             end_ts=0.0,
1071             slower_than_local_p95=False,
1072             slower_than_global_p95=False,
1073             src_bundle=src_bundle,
1074             is_cancelled=threading.Event(),
1075             was_cancelled=False,
1076             backup_bundles=None,  # backup backups not allowed
1077             failure_count=0,
1078         )
1079         src_bundle.backup_bundles.append(backup_bundle)
1080         self.status.record_bundle_details_already_locked(backup_bundle)
1081         logger.debug('%s: Created a backup bundle', backup_bundle)
1082         return backup_bundle
1083
1084     def schedule_backup_for_bundle(self, src_bundle: BundleDetails):
1085         assert self.status.lock.locked()
1086         assert src_bundle is not None
1087         backup_bundle = self.create_backup_bundle(src_bundle)
1088         logger.debug(
1089             '%s/%s: Scheduling backup for execution...', backup_bundle.uuid, backup_bundle.fname
1090         )
1091         self._helper_executor.submit(self.launch, backup_bundle)
1092
1093         # Results from backups don't matter; if they finish first
1094         # they will move the result_file to this machine and let
1095         # the original pick them up and unpickle them.
1096
1097     def emergency_retry_nasty_bundle(self, bundle: BundleDetails) -> Optional[fut.Future]:
1098         is_original = bundle.src_bundle is None
1099         bundle.worker = None
1100         avoid_last_machine = bundle.machine
1101         bundle.machine = None
1102         bundle.username = None
1103         bundle.failure_count += 1
1104         if is_original:
1105             retry_limit = 3
1106         else:
1107             retry_limit = 2
1108
1109         if bundle.failure_count > retry_limit:
1110             logger.error(
1111                 '%s: Tried this bundle too many times already (%dx); giving up.',
1112                 bundle,
1113                 retry_limit,
1114             )
1115             if is_original:
1116                 raise RemoteExecutorException(
1117                     f'{bundle}: This bundle can\'t be completed despite several backups and retries',
1118                 )
1119             else:
1120                 logger.error(
1121                     '%s: At least it\'s only a backup; better luck with the others.', bundle
1122                 )
1123             return None
1124         else:
1125             msg = f'>>> Emergency rescheduling {bundle} because of unexpected errors (wtf?!) <<<'
1126             logger.warning(msg)
1127             warnings.warn(msg)
1128             return self.launch(bundle, avoid_last_machine)
1129
1130     @overrides
1131     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
1132         if self.already_shutdown:
1133             raise Exception('Submitted work after shutdown.')
1134         pickle = make_cloud_pickle(function, *args, **kwargs)
1135         bundle = self.create_original_bundle(pickle, function.__name__)
1136         self.total_bundles_submitted += 1
1137         return self._helper_executor.submit(self.launch, bundle)
1138
1139     @overrides
1140     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
1141         if not self.already_shutdown:
1142             logger.debug('Shutting down RemoteExecutor %s', self.title)
1143             self.heartbeat_stop_event.set()
1144             self.heartbeat_thread.join()
1145             self._helper_executor.shutdown(wait)
1146             if not quiet:
1147                 print(self.histogram.__repr__(label_formatter='%ds'))
1148             self.already_shutdown = True
1149
1150
1151 @singleton
1152 class DefaultExecutors(object):
1153     """A container for a default thread, process and remote executor.
1154     These are not created until needed and we take care to clean up
1155     before process exit.
1156
1157     """
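    # Note: the pools below are created lazily on first use; shutdown() (at
    # the bottom of this class) is what actually reaps them, so callers are
    # expected to invoke it, directly or via whatever cleanup hook this repo
    # uses, before process exit.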
1158
1159     def __init__(self):
1160         self.thread_executor: Optional[ThreadExecutor] = None
1161         self.process_executor: Optional[ProcessExecutor] = None
1162         self.remote_executor: Optional[RemoteExecutor] = None
1163
1164     @staticmethod
1165     def ping(host) -> bool:
1166         logger.debug('RUN> ping -c 1 %s', host)
1167         try:
1168             x = cmd_with_timeout(f'ping -c 1 {host} >/dev/null 2>/dev/null', timeout_seconds=1.0)
1169             return x == 0
1170         except Exception:
1171             return False
1172
1173     def thread_pool(self) -> ThreadExecutor:
1174         if self.thread_executor is None:
1175             self.thread_executor = ThreadExecutor()
1176         return self.thread_executor
1177
1178     def process_pool(self) -> ProcessExecutor:
1179         if self.process_executor is None:
1180             self.process_executor = ProcessExecutor()
1181         return self.process_executor
1182
1183     def remote_pool(self) -> RemoteExecutor:
1184         if self.remote_executor is None:
1185             logger.info('Looking for some helper machines...')
1186             pool: List[RemoteWorkerRecord] = []
1187             if self.ping('cheetah.house'):
1188                 logger.info('Found cheetah.house')
1189                 pool.append(
1190                     RemoteWorkerRecord(
1191                         username='scott',
1192                         machine='cheetah.house',
1193                         weight=24,
1194                         count=6,
1195                     ),
1196                 )
1197             if self.ping('meerkat.cabin'):
1198                 logger.info('Found meerkat.cabin')
1199                 pool.append(
1200                     RemoteWorkerRecord(
1201                         username='scott',
1202                         machine='meerkat.cabin',
1203                         weight=12,
1204                         count=2,
1205                     ),
1206                 )
1207             if self.ping('wannabe.house'):
1208                 logger.info('Found wannabe.house')
1209                 pool.append(
1210                     RemoteWorkerRecord(
1211                         username='scott',
1212                         machine='wannabe.house',
1213                         weight=14,
1214                         count=8,
1215                     ),
1216                 )
1217             if self.ping('puma.cabin'):
1218                 logger.info('Found puma.cabin')
1219                 pool.append(
1220                     RemoteWorkerRecord(
1221                         username='scott',
1222                         machine='puma.cabin',
1223                         weight=24,
1224                         count=6,
1225                     ),
1226                 )
1227             if self.ping('backup.house'):
1228                 logger.info('Found backup.house')
1229                 pool.append(
1230                     RemoteWorkerRecord(
1231                         username='scott',
1232                         machine='backup.house',
1233                         weight=9,
1234                         count=2,
1235                     ),
1236                 )
1237
1238             # The controller machine has a lot to do; go easy on it.
1239             for record in pool:
1240                 if record.machine == platform.node() and record.count > 1:
1241                     logger.info('Reducing workload for %s.', record.machine)
1242                     record.count = 1
1243
1244             policy = WeightedRandomRemoteWorkerSelectionPolicy()
1245             policy.register_worker_pool(pool)
1246             self.remote_executor = RemoteExecutor(pool, policy)
1247         return self.remote_executor
1248
1249     def shutdown(self) -> None:
1250         if self.thread_executor is not None:
1251             self.thread_executor.shutdown(wait=True, quiet=True)
1252             self.thread_executor = None
1253         if self.process_executor is not None:
1254             self.process_executor.shutdown(wait=True, quiet=True)
1255             self.process_executor = None
1256         if self.remote_executor is not None:
1257             self.remote_executor.shutdown(wait=True, quiet=True)
1258             self.remote_executor = None