Used isort to sort imports. Also added to the git pre-commit hook.
[python_utils.git] / executors.py
1 #!/usr/bin/env python3
2
3 from __future__ import annotations
4
5 import concurrent.futures as fut
6 import logging
7 import os
8 import platform
9 import random
10 import subprocess
11 import threading
12 import time
13 import warnings
14 from abc import ABC, abstractmethod
15 from collections import defaultdict
16 from dataclasses import dataclass
17 from typing import Any, Callable, Dict, List, Optional, Set
18
19 import cloudpickle  # type: ignore
20 import numpy
21 from overrides import overrides
22
23 import argparse_utils
24 import config
25 import histogram as hist
26 from ansi import bg, fg, reset, underline
27 from decorator_utils import singleton
28 from exec_utils import cmd_in_background, cmd_with_timeout, run_silently
29 from thread_utils import background_thread
30
31 logger = logging.getLogger(__name__)
32
33 parser = config.add_commandline_args(
34     f"Executors ({__file__})", "Args related to processing executors."
35 )
36 parser.add_argument(
37     '--executors_threadpool_size',
38     type=int,
39     metavar='#THREADS',
40     help='Number of threads in the default threadpool, leave unset for default',
41     default=None,
42 )
43 parser.add_argument(
44     '--executors_processpool_size',
45     type=int,
46     metavar='#PROCESSES',
47     help='Number of processes in the default processpool, leave unset for default',
48     default=None,
49 )
50 parser.add_argument(
51     '--executors_schedule_remote_backups',
52     default=True,
53     action=argparse_utils.ActionNoYes,
54     help='Should we schedule duplicative backup work if a remote bundle is slow',
55 )
56 parser.add_argument(
57     '--executors_max_bundle_failures',
58     type=int,
59     default=3,
60     metavar='#FAILURES',
61     help='Maximum number of failures before giving up on a bundle',
62 )
63
64 SSH = '/usr/bin/ssh -oForwardX11=no'
65 SCP = '/usr/bin/scp -C'
66
67
68 def make_cloud_pickle(fun, *args, **kwargs):
69     logger.debug(f"Making cloudpickled bundle at {fun.__name__}")
70     return cloudpickle.dumps((fun, args, kwargs))
71
72
73 class BaseExecutor(ABC):
74     def __init__(self, *, title=''):
75         self.title = title
76         self.histogram = hist.SimpleHistogram(
77             hist.SimpleHistogram.n_evenly_spaced_buckets(int(0), int(500), 50)
78         )
79         self.task_count = 0
80
81     @abstractmethod
82     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
83         pass
84
85     @abstractmethod
86     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
87         pass
88
89     def shutdown_if_idle(self, *, quiet: bool = False) -> bool:
90         """Shutdown the executor and return True if the executor is idle
91         (i.e. there are no pending or active tasks).  Return False
92         otherwise.  Note: this should only be called by the launcher
93         process.
94
95         """
96         if self.task_count == 0:
97             self.shutdown(wait=True, quiet=quiet)
98             return True
99         return False
100
101     def adjust_task_count(self, delta: int) -> None:
102         """Change the task count.  Note: do not call this method from a
103         worker, it should only be called by the launcher process /
104         thread / machine.
105
106         """
107         self.task_count += delta
108         logger.debug(f'Adjusted task count by {delta} to {self.task_count}')
109
110     def get_task_count(self) -> int:
111         """Change the task count.  Note: do not call this method from a
112         worker, it should only be called by the launcher process /
113         thread / machine.
114
115         """
116         return self.task_count
117
118
119 class ThreadExecutor(BaseExecutor):
120     def __init__(self, max_workers: Optional[int] = None):
121         super().__init__()
122         workers = None
123         if max_workers is not None:
124             workers = max_workers
125         elif 'executors_threadpool_size' in config.config:
126             workers = config.config['executors_threadpool_size']
127         logger.debug(f'Creating threadpool executor with {workers} workers')
128         self._thread_pool_executor = fut.ThreadPoolExecutor(
129             max_workers=workers, thread_name_prefix="thread_executor_helper"
130         )
131         self.already_shutdown = False
132
133     # This is run on a different thread; do not adjust task count here.
134     def run_local_bundle(self, fun, *args, **kwargs):
135         logger.debug(f"Running local bundle at {fun.__name__}")
136         result = fun(*args, **kwargs)
137         return result
138
139     @overrides
140     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
141         if self.already_shutdown:
142             raise Exception('Submitted work after shutdown.')
143         self.adjust_task_count(+1)
144         newargs = []
145         newargs.append(function)
146         for arg in args:
147             newargs.append(arg)
148         start = time.time()
149         result = self._thread_pool_executor.submit(
150             self.run_local_bundle, *newargs, **kwargs
151         )
152         result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
153         result.add_done_callback(lambda _: self.adjust_task_count(-1))
154         return result
155
156     @overrides
157     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
158         if not self.already_shutdown:
159             logger.debug(f'Shutting down threadpool executor {self.title}')
160             self._thread_pool_executor.shutdown(wait)
161             if not quiet:
162                 print(self.histogram.__repr__(label_formatter='%ds'))
163             self.already_shutdown = True
164
165
166 class ProcessExecutor(BaseExecutor):
167     def __init__(self, max_workers=None):
168         super().__init__()
169         workers = None
170         if max_workers is not None:
171             workers = max_workers
172         elif 'executors_processpool_size' in config.config:
173             workers = config.config['executors_processpool_size']
174         logger.debug(f'Creating processpool executor with {workers} workers.')
175         self._process_executor = fut.ProcessPoolExecutor(
176             max_workers=workers,
177         )
178         self.already_shutdown = False
179
180     # This is run in another process; do not adjust task count here.
181     def run_cloud_pickle(self, pickle):
182         fun, args, kwargs = cloudpickle.loads(pickle)
183         logger.debug(f"Running pickled bundle at {fun.__name__}")
184         result = fun(*args, **kwargs)
185         return result
186
187     @overrides
188     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
189         if self.already_shutdown:
190             raise Exception('Submitted work after shutdown.')
191         start = time.time()
192         self.adjust_task_count(+1)
193         pickle = make_cloud_pickle(function, *args, **kwargs)
194         result = self._process_executor.submit(self.run_cloud_pickle, pickle)
195         result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
196         result.add_done_callback(lambda _: self.adjust_task_count(-1))
197         return result
198
199     @overrides
200     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
201         if not self.already_shutdown:
202             logger.debug(f'Shutting down processpool executor {self.title}')
203             self._process_executor.shutdown(wait)
204             if not quiet:
205                 print(self.histogram.__repr__(label_formatter='%ds'))
206             self.already_shutdown = True
207
208     def __getstate__(self):
209         state = self.__dict__.copy()
210         state['_process_executor'] = None
211         return state
212
213
214 class RemoteExecutorException(Exception):
215     """Thrown when a bundle cannot be executed despite several retries."""
216
217     pass
218
219
220 @dataclass
221 class RemoteWorkerRecord:
222     username: str
223     machine: str
224     weight: int
225     count: int
226
227     def __hash__(self):
228         return hash((self.username, self.machine))
229
230     def __repr__(self):
231         return f'{self.username}@{self.machine}'
232
233
234 @dataclass
235 class BundleDetails:
236     pickled_code: bytes
237     uuid: str
238     fname: str
239     worker: Optional[RemoteWorkerRecord]
240     username: Optional[str]
241     machine: Optional[str]
242     hostname: str
243     code_file: str
244     result_file: str
245     pid: int
246     start_ts: float
247     end_ts: float
248     slower_than_local_p95: bool
249     slower_than_global_p95: bool
250     src_bundle: Optional[BundleDetails]
251     is_cancelled: threading.Event
252     was_cancelled: bool
253     backup_bundles: Optional[List[BundleDetails]]
254     failure_count: int
255
256     def __repr__(self):
257         uuid = self.uuid
258         if uuid[-9:-2] == '_backup':
259             uuid = uuid[:-9]
260             suffix = f'{uuid[-6:]}_b{self.uuid[-1:]}'
261         else:
262             suffix = uuid[-6:]
263
264         colorz = [
265             fg('violet red'),
266             fg('red'),
267             fg('orange'),
268             fg('peach orange'),
269             fg('yellow'),
270             fg('marigold yellow'),
271             fg('green yellow'),
272             fg('tea green'),
273             fg('cornflower blue'),
274             fg('turquoise blue'),
275             fg('tropical blue'),
276             fg('lavender purple'),
277             fg('medium purple'),
278         ]
279         c = colorz[int(uuid[-2:], 16) % len(colorz)]
280         fname = self.fname if self.fname is not None else 'nofname'
281         machine = self.machine if self.machine is not None else 'nomachine'
282         return f'{c}{suffix}/{fname}/{machine}{reset()}'
283
284
285 class RemoteExecutorStatus:
286     def __init__(self, total_worker_count: int) -> None:
287         self.worker_count: int = total_worker_count
288         self.known_workers: Set[RemoteWorkerRecord] = set()
289         self.start_time: float = time.time()
290         self.start_per_bundle: Dict[str, Optional[float]] = defaultdict(float)
291         self.end_per_bundle: Dict[str, float] = defaultdict(float)
292         self.finished_bundle_timings_per_worker: Dict[
293             RemoteWorkerRecord, List[float]
294         ] = {}
295         self.in_flight_bundles_by_worker: Dict[RemoteWorkerRecord, Set[str]] = {}
296         self.bundle_details_by_uuid: Dict[str, BundleDetails] = {}
297         self.finished_bundle_timings: List[float] = []
298         self.last_periodic_dump: Optional[float] = None
299         self.total_bundles_submitted: int = 0
300
301         # Protects reads and modification using self.  Also used
302         # as a memory fence for modifications to bundle.
303         self.lock: threading.Lock = threading.Lock()
304
305     def record_acquire_worker(self, worker: RemoteWorkerRecord, uuid: str) -> None:
306         with self.lock:
307             self.record_acquire_worker_already_locked(worker, uuid)
308
309     def record_acquire_worker_already_locked(
310         self, worker: RemoteWorkerRecord, uuid: str
311     ) -> None:
312         assert self.lock.locked()
313         self.known_workers.add(worker)
314         self.start_per_bundle[uuid] = None
315         x = self.in_flight_bundles_by_worker.get(worker, set())
316         x.add(uuid)
317         self.in_flight_bundles_by_worker[worker] = x
318
319     def record_bundle_details(self, details: BundleDetails) -> None:
320         with self.lock:
321             self.record_bundle_details_already_locked(details)
322
323     def record_bundle_details_already_locked(self, details: BundleDetails) -> None:
324         assert self.lock.locked()
325         self.bundle_details_by_uuid[details.uuid] = details
326
327     def record_release_worker(
328         self,
329         worker: RemoteWorkerRecord,
330         uuid: str,
331         was_cancelled: bool,
332     ) -> None:
333         with self.lock:
334             self.record_release_worker_already_locked(worker, uuid, was_cancelled)
335
336     def record_release_worker_already_locked(
337         self,
338         worker: RemoteWorkerRecord,
339         uuid: str,
340         was_cancelled: bool,
341     ) -> None:
342         assert self.lock.locked()
343         ts = time.time()
344         self.end_per_bundle[uuid] = ts
345         self.in_flight_bundles_by_worker[worker].remove(uuid)
346         if not was_cancelled:
347             start = self.start_per_bundle[uuid]
348             assert start
349             bundle_latency = ts - start
350             x = self.finished_bundle_timings_per_worker.get(worker, list())
351             x.append(bundle_latency)
352             self.finished_bundle_timings_per_worker[worker] = x
353             self.finished_bundle_timings.append(bundle_latency)
354
355     def record_processing_began(self, uuid: str):
356         with self.lock:
357             self.start_per_bundle[uuid] = time.time()
358
359     def total_in_flight(self) -> int:
360         assert self.lock.locked()
361         total_in_flight = 0
362         for worker in self.known_workers:
363             total_in_flight += len(self.in_flight_bundles_by_worker[worker])
364         return total_in_flight
365
366     def total_idle(self) -> int:
367         assert self.lock.locked()
368         return self.worker_count - self.total_in_flight()
369
370     def __repr__(self):
371         assert self.lock.locked()
372         ts = time.time()
373         total_finished = len(self.finished_bundle_timings)
374         total_in_flight = self.total_in_flight()
375         ret = f'\n\n{underline()}Remote Executor Pool Status{reset()}: '
376         qall = None
377         if len(self.finished_bundle_timings) > 1:
378             qall = numpy.quantile(self.finished_bundle_timings, [0.5, 0.95])
379             ret += (
380                 f'⏱=∀p50:{qall[0]:.1f}s, ∀p95:{qall[1]:.1f}s, total={ts-self.start_time:.1f}s, '
381                 f'✅={total_finished}/{self.total_bundles_submitted}, '
382                 f'💻n={total_in_flight}/{self.worker_count}\n'
383             )
384         else:
385             ret += (
386                 f'⏱={ts-self.start_time:.1f}s, '
387                 f'✅={total_finished}/{self.total_bundles_submitted}, '
388                 f'💻n={total_in_flight}/{self.worker_count}\n'
389             )
390
391         for worker in self.known_workers:
392             ret += f'  {fg("lightning yellow")}{worker.machine}{reset()}: '
393             timings = self.finished_bundle_timings_per_worker.get(worker, [])
394             count = len(timings)
395             qworker = None
396             if count > 1:
397                 qworker = numpy.quantile(timings, [0.5, 0.95])
398                 ret += f' 💻p50: {qworker[0]:.1f}s, 💻p95: {qworker[1]:.1f}s\n'
399             else:
400                 ret += '\n'
401             if count > 0:
402                 ret += f'    ...finished {count} total bundle(s) so far\n'
403             in_flight = len(self.in_flight_bundles_by_worker[worker])
404             if in_flight > 0:
405                 ret += f'    ...{in_flight} bundles currently in flight:\n'
406                 for bundle_uuid in self.in_flight_bundles_by_worker[worker]:
407                     details = self.bundle_details_by_uuid.get(bundle_uuid, None)
408                     pid = str(details.pid) if (details and details.pid != 0) else "TBD"
409                     if self.start_per_bundle[bundle_uuid] is not None:
410                         sec = ts - self.start_per_bundle[bundle_uuid]
411                         ret += f'       (pid={pid}): {details} for {sec:.1f}s so far '
412                     else:
413                         ret += f'       {details} setting up / copying data...'
414                         sec = 0.0
415
416                     if qworker is not None:
417                         if sec > qworker[1]:
418                             ret += f'{bg("red")}>💻p95{reset()} '
419                             if details is not None:
420                                 details.slower_than_local_p95 = True
421                         else:
422                             if details is not None:
423                                 details.slower_than_local_p95 = False
424
425                     if qall is not None:
426                         if sec > qall[1]:
427                             ret += f'{bg("red")}>∀p95{reset()} '
428                             if details is not None:
429                                 details.slower_than_global_p95 = True
430                         else:
431                             details.slower_than_global_p95 = False
432                     ret += '\n'
433         return ret
434
435     def periodic_dump(self, total_bundles_submitted: int) -> None:
436         assert self.lock.locked()
437         self.total_bundles_submitted = total_bundles_submitted
438         ts = time.time()
439         if self.last_periodic_dump is None or ts - self.last_periodic_dump > 5.0:
440             print(self)
441             self.last_periodic_dump = ts
442
443
444 class RemoteWorkerSelectionPolicy(ABC):
445     def register_worker_pool(self, workers):
446         self.workers = workers
447
448     @abstractmethod
449     def is_worker_available(self) -> bool:
450         pass
451
452     @abstractmethod
453     def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
454         pass
455
456
457 class WeightedRandomRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
458     @overrides
459     def is_worker_available(self) -> bool:
460         for worker in self.workers:
461             if worker.count > 0:
462                 return True
463         return False
464
465     @overrides
466     def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
467         grabbag = []
468         for worker in self.workers:
469             if worker.machine != machine_to_avoid:
470                 if worker.count > 0:
471                     for _ in range(worker.count * worker.weight):
472                         grabbag.append(worker)
473
474         if len(grabbag) == 0:
475             logger.debug(
476                 f'There are no available workers that avoid {machine_to_avoid}...'
477             )
478             for worker in self.workers:
479                 if worker.count > 0:
480                     for _ in range(worker.count * worker.weight):
481                         grabbag.append(worker)
482
483         if len(grabbag) == 0:
484             logger.warning('There are no available workers?!')
485             return None
486
487         worker = random.sample(grabbag, 1)[0]
488         assert worker.count > 0
489         worker.count -= 1
490         logger.debug(f'Chose worker {worker}')
491         return worker
492
493
494 class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
495     def __init__(self) -> None:
496         self.index = 0
497
498     @overrides
499     def is_worker_available(self) -> bool:
500         for worker in self.workers:
501             if worker.count > 0:
502                 return True
503         return False
504
505     @overrides
506     def acquire_worker(
507         self, machine_to_avoid: str = None
508     ) -> Optional[RemoteWorkerRecord]:
509         x = self.index
510         while True:
511             worker = self.workers[x]
512             if worker.count > 0:
513                 worker.count -= 1
514                 x += 1
515                 if x >= len(self.workers):
516                     x = 0
517                 self.index = x
518                 logger.debug(f'Selected worker {worker}')
519                 return worker
520             x += 1
521             if x >= len(self.workers):
522                 x = 0
523             if x == self.index:
524                 msg = 'Unexpectedly could not find a worker, retrying...'
525                 logger.warning(msg)
526                 return None
527
528
529 class RemoteExecutor(BaseExecutor):
530     def __init__(
531         self,
532         workers: List[RemoteWorkerRecord],
533         policy: RemoteWorkerSelectionPolicy,
534     ) -> None:
535         super().__init__()
536         self.workers = workers
537         self.policy = policy
538         self.worker_count = 0
539         for worker in self.workers:
540             self.worker_count += worker.count
541         if self.worker_count <= 0:
542             msg = f"We need somewhere to schedule work; count was {self.worker_count}"
543             logger.critical(msg)
544             raise RemoteExecutorException(msg)
545         self.policy.register_worker_pool(self.workers)
546         self.cv = threading.Condition()
547         logger.debug(
548             f'Creating {self.worker_count} local threads, one per remote worker.'
549         )
550         self._helper_executor = fut.ThreadPoolExecutor(
551             thread_name_prefix="remote_executor_helper",
552             max_workers=self.worker_count,
553         )
554         self.status = RemoteExecutorStatus(self.worker_count)
555         self.total_bundles_submitted = 0
556         self.backup_lock = threading.Lock()
557         self.last_backup = None
558         (
559             self.heartbeat_thread,
560             self.heartbeat_stop_event,
561         ) = self.run_periodic_heartbeat()
562         self.already_shutdown = False
563
564     @background_thread
565     def run_periodic_heartbeat(self, stop_event: threading.Event) -> None:
566         while not stop_event.is_set():
567             time.sleep(5.0)
568             logger.debug('Running periodic heartbeat code...')
569             self.heartbeat()
570         logger.debug('Periodic heartbeat thread shutting down.')
571
572     def heartbeat(self) -> None:
573         # Note: this is invoked on a background thread, not an
574         # executor thread.  Be careful what you do with it b/c it
575         # needs to get back and dump status again periodically.
576         with self.status.lock:
577             self.status.periodic_dump(self.total_bundles_submitted)
578
579             # Look for bundles to reschedule via executor.submit
580             if config.config['executors_schedule_remote_backups']:
581                 self.maybe_schedule_backup_bundles()
582
583     def maybe_schedule_backup_bundles(self):
584         assert self.status.lock.locked()
585         num_done = len(self.status.finished_bundle_timings)
586         num_idle_workers = self.worker_count - self.task_count
587         now = time.time()
588         if (
589             num_done > 2
590             and num_idle_workers > 1
591             and (self.last_backup is None or (now - self.last_backup > 9.0))
592             and self.backup_lock.acquire(blocking=False)
593         ):
594             try:
595                 assert self.backup_lock.locked()
596
597                 bundle_to_backup = None
598                 best_score = None
599                 for (
600                     worker,
601                     bundle_uuids,
602                 ) in self.status.in_flight_bundles_by_worker.items():
603
604                     # Prefer to schedule backups of bundles running on
605                     # slower machines.
606                     base_score = 0
607                     for record in self.workers:
608                         if worker.machine == record.machine:
609                             base_score = float(record.weight)
610                             base_score = 1.0 / base_score
611                             base_score *= 200.0
612                             base_score = int(base_score)
613                             break
614
615                     for uuid in bundle_uuids:
616                         bundle = self.status.bundle_details_by_uuid.get(uuid, None)
617                         if (
618                             bundle is not None
619                             and bundle.src_bundle is None
620                             and bundle.backup_bundles is not None
621                         ):
622                             score = base_score
623
624                             # Schedule backups of bundles running
625                             # longer; especially those that are
626                             # unexpectedly slow.
627                             start_ts = self.status.start_per_bundle[uuid]
628                             if start_ts is not None:
629                                 runtime = now - start_ts
630                                 score += runtime
631                                 logger.debug(
632                                     f'score[{bundle}] => {score}  # latency boost'
633                                 )
634
635                                 if bundle.slower_than_local_p95:
636                                     score += runtime / 2
637                                     logger.debug(
638                                         f'score[{bundle}] => {score}  # >worker p95'
639                                     )
640
641                                 if bundle.slower_than_global_p95:
642                                     score += runtime / 4
643                                     logger.debug(
644                                         f'score[{bundle}] => {score}  # >global p95'
645                                     )
646
647                             # Prefer backups of bundles that don't
648                             # have backups already.
649                             backup_count = len(bundle.backup_bundles)
650                             if backup_count == 0:
651                                 score *= 2
652                             elif backup_count == 1:
653                                 score /= 2
654                             elif backup_count == 2:
655                                 score /= 8
656                             else:
657                                 score = 0
658                             logger.debug(
659                                 f'score[{bundle}] => {score}  # {backup_count} dup backup factor'
660                             )
661
662                             if score != 0 and (
663                                 best_score is None or score > best_score
664                             ):
665                                 bundle_to_backup = bundle
666                                 assert bundle is not None
667                                 assert bundle.backup_bundles is not None
668                                 assert bundle.src_bundle is None
669                                 best_score = score
670
671                 # Note: this is all still happening on the heartbeat
672                 # runner thread.  That's ok because
673                 # schedule_backup_for_bundle uses the executor to
674                 # submit the bundle again which will cause it to be
675                 # picked up by a worker thread and allow this thread
676                 # to return to run future heartbeats.
677                 if bundle_to_backup is not None:
678                     self.last_backup = now
679                     logger.info(
680                         f'=====> SCHEDULING BACKUP {bundle_to_backup} (score={best_score:.1f}) <====='
681                     )
682                     self.schedule_backup_for_bundle(bundle_to_backup)
683             finally:
684                 self.backup_lock.release()
685
686     def is_worker_available(self) -> bool:
687         return self.policy.is_worker_available()
688
689     def acquire_worker(
690         self, machine_to_avoid: str = None
691     ) -> Optional[RemoteWorkerRecord]:
692         return self.policy.acquire_worker(machine_to_avoid)
693
694     def find_available_worker_or_block(
695         self, machine_to_avoid: str = None
696     ) -> RemoteWorkerRecord:
697         with self.cv:
698             while not self.is_worker_available():
699                 self.cv.wait()
700             worker = self.acquire_worker(machine_to_avoid)
701             if worker is not None:
702                 return worker
703         msg = "We should never reach this point in the code"
704         logger.critical(msg)
705         raise Exception(msg)
706
707     def release_worker(self, bundle: BundleDetails, *, was_cancelled=True) -> None:
708         worker = bundle.worker
709         assert worker is not None
710         logger.debug(f'Released worker {worker}')
711         self.status.record_release_worker(
712             worker,
713             bundle.uuid,
714             was_cancelled,
715         )
716         with self.cv:
717             worker.count += 1
718             self.cv.notify()
719         self.adjust_task_count(-1)
720
721     def check_if_cancelled(self, bundle: BundleDetails) -> bool:
722         with self.status.lock:
723             if bundle.is_cancelled.wait(timeout=0.0):
724                 logger.debug(f'Bundle {bundle.uuid} is cancelled, bail out.')
725                 bundle.was_cancelled = True
726                 return True
727         return False
728
729     def launch(self, bundle: BundleDetails, override_avoid_machine=None) -> Any:
730         """Find a worker for bundle or block until one is available."""
731         self.adjust_task_count(+1)
732         uuid = bundle.uuid
733         hostname = bundle.hostname
734         avoid_machine = override_avoid_machine
735         is_original = bundle.src_bundle is None
736
737         # Try not to schedule a backup on the same host as the original.
738         if avoid_machine is None and bundle.src_bundle is not None:
739             avoid_machine = bundle.src_bundle.machine
740         worker = None
741         while worker is None:
742             worker = self.find_available_worker_or_block(avoid_machine)
743         assert worker
744
745         # Ok, found a worker.
746         bundle.worker = worker
747         machine = bundle.machine = worker.machine
748         username = bundle.username = worker.username
749         self.status.record_acquire_worker(worker, uuid)
750         logger.debug(f'{bundle}: Running bundle on {worker}...')
751
752         # Before we do any work, make sure the bundle is still viable.
753         # It may have been some time between when it was submitted and
754         # now due to lack of worker availability and someone else may
755         # have already finished it.
756         if self.check_if_cancelled(bundle):
757             try:
758                 return self.process_work_result(bundle)
759             except Exception as e:
760                 logger.warning(
761                     f'{bundle}: bundle says it\'s cancelled upfront but no results?!'
762                 )
763                 self.release_worker(bundle)
764                 if is_original:
765                     # Weird.  We are the original owner of this
766                     # bundle.  For it to have been cancelled, a backup
767                     # must have already started and completed before
768                     # we even for started.  Moreover, the backup says
769                     # it is done but we can't find the results it
770                     # should have copied over.  Reschedule the whole
771                     # thing.
772                     logger.exception(e)
773                     logger.error(
774                         f'{bundle}: We are the original owner thread and yet there are '
775                         + 'no results for this bundle.  This is unexpected and bad.'
776                     )
777                     return self.emergency_retry_nasty_bundle(bundle)
778                 else:
779                     # Expected(?).  We're a backup and our bundle is
780                     # cancelled before we even got started.  Something
781                     # went bad in process_work_result (I acutually don't
782                     # see what?) but probably not worth worrying
783                     # about.  Let the original thread worry about
784                     # either finding the results or complaining about
785                     # it.
786                     return None
787
788         # Send input code / data to worker machine if it's not local.
789         if hostname not in machine:
790             try:
791                 cmd = (
792                     f'{SCP} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
793                 )
794                 start_ts = time.time()
795                 logger.info(f"{bundle}: Copying work to {worker} via {cmd}.")
796                 run_silently(cmd)
797                 xfer_latency = time.time() - start_ts
798                 logger.debug(f"{bundle}: Copying to {worker} took {xfer_latency:.1f}s.")
799             except Exception as e:
800                 self.release_worker(bundle)
801                 if is_original:
802                     # Weird.  We tried to copy the code to the worker and it failed...
803                     # And we're the original bundle.  We have to retry.
804                     logger.exception(e)
805                     logger.error(
806                         f"{bundle}: Failed to send instructions to the worker machine?! "
807                         + "This is not expected; we\'re the original bundle so this shouldn\'t "
808                         + "be a race condition.  Attempting an emergency retry..."
809                     )
810                     return self.emergency_retry_nasty_bundle(bundle)
811                 else:
812                     # This is actually expected; we're a backup.
813                     # There's a race condition where someone else
814                     # already finished the work and removed the source
815                     # code file before we could copy it.  No biggie.
816                     msg = f'{bundle}: Failed to send instructions to the worker machine... '
817                     msg += 'We\'re a backup and this may be caused by the original (or some '
818                     msg += 'other backup) already finishing this work.  Ignoring this.'
819                     logger.warning(msg)
820                     return None
821
822         # Kick off the work.  Note that if this fails we let
823         # wait_for_process deal with it.
824         self.status.record_processing_began(uuid)
825         cmd = (
826             f'{SSH} {bundle.username}@{bundle.machine} '
827             f'"source py38-venv/bin/activate &&'
828             f' /home/scott/lib/python_modules/remote_worker.py'
829             f' --code_file {bundle.code_file} --result_file {bundle.result_file}"'
830         )
831         logger.debug(f'{bundle}: Executing {cmd} in the background to kick off work...')
832         p = cmd_in_background(cmd, silent=True)
833         bundle.pid = p.pid
834         logger.debug(
835             f'{bundle}: Local ssh process pid={p.pid}; remote worker is {machine}.'
836         )
837         return self.wait_for_process(p, bundle, 0)
838
839     def wait_for_process(
840         self, p: Optional[subprocess.Popen], bundle: BundleDetails, depth: int
841     ) -> Any:
842         machine = bundle.machine
843         assert p
844         pid = p.pid
845         if depth > 3:
846             logger.error(
847                 f"I've gotten repeated errors waiting on this bundle; giving up on pid={pid}."
848             )
849             p.terminate()
850             self.release_worker(bundle)
851             return self.emergency_retry_nasty_bundle(bundle)
852
853         # Spin until either the ssh job we scheduled finishes the
854         # bundle or some backup worker signals that they finished it
855         # before we could.
856         while True:
857             try:
858                 p.wait(timeout=0.25)
859             except subprocess.TimeoutExpired:
860                 if self.check_if_cancelled(bundle):
861                     logger.info(
862                         f'{bundle}: looks like another worker finished bundle...'
863                     )
864                     break
865             else:
866                 logger.info(f"{bundle}: pid {pid} ({machine}) is finished!")
867                 p = None
868                 break
869
870         # If we get here we believe the bundle is done; either the ssh
871         # subprocess finished (hopefully successfully) or we noticed
872         # that some other worker seems to have completed the bundle
873         # and we're bailing out.
874         try:
875             ret = self.process_work_result(bundle)
876             if ret is not None and p is not None:
877                 p.terminate()
878             return ret
879
880         # Something went wrong; e.g. we could not copy the results
881         # back, cleanup after ourselves on the remote machine, or
882         # unpickle the results we got from the remove machine.  If we
883         # still have an active ssh subprocess, keep waiting on it.
884         # Otherwise, time for an emergency reschedule.
885         except Exception as e:
886             logger.exception(e)
887             logger.error(f'{bundle}: Something unexpected just happened...')
888             if p is not None:
889                 msg = f"{bundle}: Failed to wrap up \"done\" bundle, re-waiting on active ssh."
890                 logger.warning(msg)
891                 return self.wait_for_process(p, bundle, depth + 1)
892             else:
893                 self.release_worker(bundle)
894                 return self.emergency_retry_nasty_bundle(bundle)
895
896     def process_work_result(self, bundle: BundleDetails) -> Any:
897         with self.status.lock:
898             is_original = bundle.src_bundle is None
899             was_cancelled = bundle.was_cancelled
900             username = bundle.username
901             machine = bundle.machine
902             result_file = bundle.result_file
903             code_file = bundle.code_file
904
905             # Whether original or backup, if we finished first we must
906             # fetch the results if the computation happened on a
907             # remote machine.
908             bundle.end_ts = time.time()
909             if not was_cancelled:
910                 assert bundle.machine is not None
911                 if bundle.hostname not in bundle.machine:
912                     cmd = f'{SCP} {username}@{machine}:{result_file} {result_file} 2>/dev/null'
913                     logger.info(
914                         f"{bundle}: Fetching results back from {username}@{machine} via {cmd}"
915                     )
916
917                     # If either of these throw they are handled in
918                     # wait_for_process.
919                     attempts = 0
920                     while True:
921                         try:
922                             run_silently(cmd)
923                         except Exception as e:
924                             attempts += 1
925                             if attempts >= 3:
926                                 raise (e)
927                         else:
928                             break
929
930                     run_silently(
931                         f'{SSH} {username}@{machine}'
932                         f' "/bin/rm -f {code_file} {result_file}"'
933                     )
934                     logger.debug(
935                         f'Fetching results back took {time.time() - bundle.end_ts:.1f}s.'
936                     )
937                 dur = bundle.end_ts - bundle.start_ts
938                 self.histogram.add_item(dur)
939
940         # Only the original worker should unpickle the file contents
941         # though since it's the only one whose result matters.  The
942         # original is also the only job that may delete result_file
943         # from disk.  Note that the original may have been cancelled
944         # if one of the backups finished first; it still must read the
945         # result from disk.
946         if is_original:
947             logger.debug(f"{bundle}: Unpickling {result_file}.")
948             try:
949                 with open(result_file, 'rb') as rb:
950                     serialized = rb.read()
951                 result = cloudpickle.loads(serialized)
952             except Exception as e:
953                 logger.exception(e)
954                 msg = f'Failed to load {result_file}... this is bad news.'
955                 logger.critical(msg)
956                 self.release_worker(bundle)
957
958                 # Re-raise the exception; the code in wait_for_process may
959                 # decide to emergency_retry_nasty_bundle here.
960                 raise Exception(e)
961             logger.debug(f'Removing local (master) {code_file} and {result_file}.')
962             os.remove(f'{result_file}')
963             os.remove(f'{code_file}')
964
965             # Notify any backups that the original is done so they
966             # should stop ASAP.  Do this whether or not we
967             # finished first since there could be more than one
968             # backup.
969             if bundle.backup_bundles is not None:
970                 for backup in bundle.backup_bundles:
971                     logger.debug(
972                         f'{bundle}: Notifying backup {backup.uuid} that it\'s cancelled'
973                     )
974                     backup.is_cancelled.set()
975
976         # This is a backup job and, by now, we have already fetched
977         # the bundle results.
978         else:
979             # Backup results don't matter, they just need to leave the
980             # result file in the right place for their originals to
981             # read/unpickle later.
982             result = None
983
984             # Tell the original to stop if we finished first.
985             if not was_cancelled:
986                 orig_bundle = bundle.src_bundle
987                 assert orig_bundle
988                 logger.debug(
989                     f'{bundle}: Notifying original {orig_bundle.uuid} we beat them to it.'
990                 )
991                 orig_bundle.is_cancelled.set()
992         self.release_worker(bundle, was_cancelled=was_cancelled)
993         return result
994
995     def create_original_bundle(self, pickle, fname: str):
996         from string_utils import generate_uuid
997
998         uuid = generate_uuid(omit_dashes=True)
999         code_file = f'/tmp/{uuid}.code.bin'
1000         result_file = f'/tmp/{uuid}.result.bin'
1001
1002         logger.debug(f'Writing pickled code to {code_file}')
1003         with open(f'{code_file}', 'wb') as wb:
1004             wb.write(pickle)
1005
1006         bundle = BundleDetails(
1007             pickled_code=pickle,
1008             uuid=uuid,
1009             fname=fname,
1010             worker=None,
1011             username=None,
1012             machine=None,
1013             hostname=platform.node(),
1014             code_file=code_file,
1015             result_file=result_file,
1016             pid=0,
1017             start_ts=time.time(),
1018             end_ts=0.0,
1019             slower_than_local_p95=False,
1020             slower_than_global_p95=False,
1021             src_bundle=None,
1022             is_cancelled=threading.Event(),
1023             was_cancelled=False,
1024             backup_bundles=[],
1025             failure_count=0,
1026         )
1027         self.status.record_bundle_details(bundle)
1028         logger.debug(f'{bundle}: Created an original bundle')
1029         return bundle
1030
1031     def create_backup_bundle(self, src_bundle: BundleDetails):
1032         assert src_bundle.backup_bundles is not None
1033         n = len(src_bundle.backup_bundles)
1034         uuid = src_bundle.uuid + f'_backup#{n}'
1035
1036         backup_bundle = BundleDetails(
1037             pickled_code=src_bundle.pickled_code,
1038             uuid=uuid,
1039             fname=src_bundle.fname,
1040             worker=None,
1041             username=None,
1042             machine=None,
1043             hostname=src_bundle.hostname,
1044             code_file=src_bundle.code_file,
1045             result_file=src_bundle.result_file,
1046             pid=0,
1047             start_ts=time.time(),
1048             end_ts=0.0,
1049             slower_than_local_p95=False,
1050             slower_than_global_p95=False,
1051             src_bundle=src_bundle,
1052             is_cancelled=threading.Event(),
1053             was_cancelled=False,
1054             backup_bundles=None,  # backup backups not allowed
1055             failure_count=0,
1056         )
1057         src_bundle.backup_bundles.append(backup_bundle)
1058         self.status.record_bundle_details_already_locked(backup_bundle)
1059         logger.debug(f'{backup_bundle}: Created a backup bundle')
1060         return backup_bundle
1061
1062     def schedule_backup_for_bundle(self, src_bundle: BundleDetails):
1063         assert self.status.lock.locked()
1064         assert src_bundle is not None
1065         backup_bundle = self.create_backup_bundle(src_bundle)
1066         logger.debug(
1067             f'{backup_bundle.uuid}/{backup_bundle.fname}: Scheduling backup for execution...'
1068         )
1069         self._helper_executor.submit(self.launch, backup_bundle)
1070
1071         # Results from backups don't matter; if they finish first
1072         # they will move the result_file to this machine and let
1073         # the original pick them up and unpickle them.
1074
1075     def emergency_retry_nasty_bundle(
1076         self, bundle: BundleDetails
1077     ) -> Optional[fut.Future]:
1078         is_original = bundle.src_bundle is None
1079         bundle.worker = None
1080         avoid_last_machine = bundle.machine
1081         bundle.machine = None
1082         bundle.username = None
1083         bundle.failure_count += 1
1084         if is_original:
1085             retry_limit = 3
1086         else:
1087             retry_limit = 2
1088
1089         if bundle.failure_count > retry_limit:
1090             logger.error(
1091                 f'{bundle}: Tried this bundle too many times already ({retry_limit}x); giving up.'
1092             )
1093             if is_original:
1094                 raise RemoteExecutorException(
1095                     f'{bundle}: This bundle can\'t be completed despite several backups and retries'
1096                 )
1097             else:
1098                 logger.error(
1099                     f'{bundle}: At least it\'s only a backup; better luck with the others.'
1100                 )
1101             return None
1102         else:
1103             msg = f'>>> Emergency rescheduling {bundle} because of unexected errors (wtf?!) <<<'
1104             logger.warning(msg)
1105             warnings.warn(msg)
1106             return self.launch(bundle, avoid_last_machine)
1107
1108     @overrides
1109     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
1110         if self.already_shutdown:
1111             raise Exception('Submitted work after shutdown.')
1112         pickle = make_cloud_pickle(function, *args, **kwargs)
1113         bundle = self.create_original_bundle(pickle, function.__name__)
1114         self.total_bundles_submitted += 1
1115         return self._helper_executor.submit(self.launch, bundle)
1116
1117     @overrides
1118     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
1119         if not self.already_shutdown:
1120             logging.debug(f'Shutting down RemoteExecutor {self.title}')
1121             self.heartbeat_stop_event.set()
1122             self.heartbeat_thread.join()
1123             self._helper_executor.shutdown(wait)
1124             if not quiet:
1125                 print(self.histogram.__repr__(label_formatter='%ds'))
1126             self.already_shutdown = True
1127
1128
1129 @singleton
1130 class DefaultExecutors(object):
1131     def __init__(self):
1132         self.thread_executor: Optional[ThreadExecutor] = None
1133         self.process_executor: Optional[ProcessExecutor] = None
1134         self.remote_executor: Optional[RemoteExecutor] = None
1135
1136     def ping(self, host) -> bool:
1137         logger.debug(f'RUN> ping -c 1 {host}')
1138         try:
1139             x = cmd_with_timeout(
1140                 f'ping -c 1 {host} >/dev/null 2>/dev/null', timeout_seconds=1.0
1141             )
1142             return x == 0
1143         except Exception:
1144             return False
1145
1146     def thread_pool(self) -> ThreadExecutor:
1147         if self.thread_executor is None:
1148             self.thread_executor = ThreadExecutor()
1149         return self.thread_executor
1150
1151     def process_pool(self) -> ProcessExecutor:
1152         if self.process_executor is None:
1153             self.process_executor = ProcessExecutor()
1154         return self.process_executor
1155
1156     def remote_pool(self) -> RemoteExecutor:
1157         if self.remote_executor is None:
1158             logger.info('Looking for some helper machines...')
1159             pool: List[RemoteWorkerRecord] = []
1160             if self.ping('cheetah.house'):
1161                 logger.info('Found cheetah.house')
1162                 pool.append(
1163                     RemoteWorkerRecord(
1164                         username='scott',
1165                         machine='cheetah.house',
1166                         weight=30,
1167                         count=6,
1168                     ),
1169                 )
1170             if self.ping('meerkat.cabin'):
1171                 logger.info('Found meerkat.cabin')
1172                 pool.append(
1173                     RemoteWorkerRecord(
1174                         username='scott',
1175                         machine='meerkat.cabin',
1176                         weight=12,
1177                         count=2,
1178                     ),
1179                 )
1180             if self.ping('wannabe.house'):
1181                 logger.info('Found wannabe.house')
1182                 pool.append(
1183                     RemoteWorkerRecord(
1184                         username='scott',
1185                         machine='wannabe.house',
1186                         weight=25,
1187                         count=10,
1188                     ),
1189                 )
1190             if self.ping('puma.cabin'):
1191                 logger.info('Found puma.cabin')
1192                 pool.append(
1193                     RemoteWorkerRecord(
1194                         username='scott',
1195                         machine='puma.cabin',
1196                         weight=30,
1197                         count=6,
1198                     ),
1199                 )
1200             if self.ping('backup.house'):
1201                 logger.info('Found backup.house')
1202                 pool.append(
1203                     RemoteWorkerRecord(
1204                         username='scott',
1205                         machine='backup.house',
1206                         weight=8,
1207                         count=2,
1208                     ),
1209                 )
1210
1211             # The controller machine has a lot to do; go easy on it.
1212             for record in pool:
1213                 if record.machine == platform.node() and record.count > 1:
1214                     logger.info(f'Reducing workload for {record.machine}.')
1215                     record.count = 1
1216
1217             policy = WeightedRandomRemoteWorkerSelectionPolicy()
1218             policy.register_worker_pool(pool)
1219             self.remote_executor = RemoteExecutor(pool, policy)
1220         return self.remote_executor
1221
1222     def shutdown(self) -> None:
1223         if self.thread_executor is not None:
1224             self.thread_executor.shutdown(wait=True, quiet=True)
1225             self.thread_executor = None
1226         if self.process_executor is not None:
1227             self.process_executor.shutdown(wait=True, quiet=True)
1228             self.process_executor = None
1229         if self.remote_executor is not None:
1230             self.remote_executor.shutdown(wait=True, quiet=True)
1231             self.remote_executor = None