Fix a recent bug in executors. Thread executor needs to return
[python_utils.git] / executors.py
1 #!/usr/bin/env python3
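"""Executors: a ThreadExecutor that runs work on a local thread pool, a
ProcessExecutor that runs cloudpickled work in other processes on the same
machine, and a RemoteExecutor that farms work out to other machines over
ssh/scp.  DefaultExecutors is a lazy singleton factory for all three.

A minimal usage sketch (assumes this project's config has been initialized
elsewhere, that the @singleton decorator makes DefaultExecutors() return
the shared instance, and that some_function and its arguments are
hypothetical):

    pool = DefaultExecutors().thread_pool()
    future = pool.submit(some_function, arg1, arg2)
    print(future.result())
    pool.shutdown()
"""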
2
3 from __future__ import annotations
4
5 from abc import ABC, abstractmethod
6 import concurrent.futures as fut
7 from collections import defaultdict
8 from dataclasses import dataclass
9 import logging
10 import numpy
11 import os
12 import platform
13 import random
14 import subprocess
15 import threading
16 import time
17 from typing import Any, Callable, Dict, List, Optional, Set
18 import warnings
19
20 import cloudpickle  # type: ignore
21 from overrides import overrides
22
23 from ansi import bg, fg, underline, reset
24 import argparse_utils
25 import config
26 from decorator_utils import singleton
27 from exec_utils import run_silently, cmd_in_background, cmd_with_timeout
28 import histogram as hist
29 from thread_utils import background_thread
30
31
32 logger = logging.getLogger(__name__)
33
34 parser = config.add_commandline_args(
35     f"Executors ({__file__})", "Args related to processing executors."
36 )
37 parser.add_argument(
38     '--executors_threadpool_size',
39     type=int,
40     metavar='#THREADS',
41     help='Number of threads in the default threadpool, leave unset for default',
42     default=None,
43 )
44 parser.add_argument(
45     '--executors_processpool_size',
46     type=int,
47     metavar='#PROCESSES',
48     help='Number of processes in the default processpool, leave unset for default',
49     default=None,
50 )
51 parser.add_argument(
52     '--executors_schedule_remote_backups',
53     default=True,
54     action=argparse_utils.ActionNoYes,
55     help='Should we schedule duplicative backup work if a remote bundle is slow',
56 )
57 parser.add_argument(
58     '--executors_max_bundle_failures',
59     type=int,
60     default=3,
61     metavar='#FAILURES',
62     help='Maximum number of failures before giving up on a bundle',
63 )
64
65 SSH = '/usr/bin/ssh -oForwardX11=no'
66 SCP = '/usr/bin/scp -C'
67
68
69 def make_cloud_pickle(fun, *args, **kwargs):
70     logger.debug(f"Making cloudpickled bundle at {fun.__name__}")
71     return cloudpickle.dumps((fun, args, kwargs))
72
73
74 class BaseExecutor(ABC):
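    """The abstract executor interface: subclasses provide submit() and
    shutdown().  Also tracks an approximate task count and owns a histogram
    used to record task runtimes."""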
75     def __init__(self, *, title=''):
76         self.title = title
77         self.histogram = hist.SimpleHistogram(
78             hist.SimpleHistogram.n_evenly_spaced_buckets(int(0), int(500), 50)
79         )
80         self.task_count = 0
81
82     @abstractmethod
83     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
84         pass
85
86     @abstractmethod
87     def shutdown(self, wait: bool = True) -> None:
88         pass
89
90     def shutdown_if_idle(self) -> bool:
91         """Shutdown the executor and return True if the executor is idle
92         (i.e. there are no pending or active tasks).  Return False
93         otherwise.  Note: this should only be called by the launcher
94         process.
95
96         """
97         if self.task_count == 0:
98             self.shutdown()
99             return True
100         return False
101
102     def adjust_task_count(self, delta: int) -> None:
103         """Change the task count.  Note: do not call this method from a
104         worker; it should only be called by the launcher process /
105         thread / machine.
106
107         """
108         self.task_count += delta
109         logger.debug(f'Adjusted task count by {delta} to {self.task_count}')
110
111     def get_task_count(self) -> int:
112         """Return the current task count.  Note: do not call this method
113         from a worker; it should only be called by the launcher process /
114         thread / machine.
115
116         """
117         return self.task_count
118
119
120 class ThreadExecutor(BaseExecutor):
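    """A BaseExecutor that runs submitted work on a local
    concurrent.futures.ThreadPoolExecutor."""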
121     def __init__(self, max_workers: Optional[int] = None):
122         super().__init__()
123         workers = None
124         if max_workers is not None:
125             workers = max_workers
126         elif 'executors_threadpool_size' in config.config:
127             workers = config.config['executors_threadpool_size']
128         logger.debug(f'Creating threadpool executor with {workers} workers')
129         self._thread_pool_executor = fut.ThreadPoolExecutor(
130             max_workers=workers, thread_name_prefix="thread_executor_helper"
131         )
132         self.already_shutdown = False
133
134     # This is run on a different thread; do not adjust task count here.
135     def run_local_bundle(self, fun, *args, **kwargs):
136         logger.debug(f"Running local bundle at {fun.__name__}")
137         result = fun(*args, **kwargs)
138         return result
139
140     @overrides
141     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
142         if self.already_shutdown:
143             raise Exception('Submitted work after shutdown.')
144         self.adjust_task_count(+1)
145         newargs = []
146         newargs.append(function)
147         for arg in args:
148             newargs.append(arg)
149         start = time.time()
150         result = self._thread_pool_executor.submit(
151             self.run_local_bundle, *newargs, **kwargs
152         )
153         result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
154         result.add_done_callback(lambda _: self.adjust_task_count(-1))
155         return result
156
157     @overrides
158     def shutdown(self, wait=True) -> None:
159         if not self.already_shutdown:
160             logger.debug(f'Shutting down threadpool executor {self.title}')
161             print(self.histogram)
162             self._thread_pool_executor.shutdown(wait)
163             self.already_shutdown = True
164
165
166 class ProcessExecutor(BaseExecutor):
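    """A BaseExecutor that runs cloudpickled work on a local
    concurrent.futures.ProcessPoolExecutor."""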
167     def __init__(self, max_workers=None):
168         super().__init__()
169         workers = None
170         if max_workers is not None:
171             workers = max_workers
172         elif 'executors_processpool_size' in config.config:
173             workers = config.config['executors_processpool_size']
174         logger.debug(f'Creating processpool executor with {workers} workers.')
175         self._process_executor = fut.ProcessPoolExecutor(
176             max_workers=workers,
177         )
178         self.already_shutdown = False
179
180     # This is run in another process; do not adjust task count here.
181     def run_cloud_pickle(self, pickle):
182         fun, args, kwargs = cloudpickle.loads(pickle)
183         logger.debug(f"Running pickled bundle at {fun.__name__}")
184         result = fun(*args, **kwargs)
185         return result
186
187     @overrides
188     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
189         if self.already_shutdown:
190             raise Exception('Submitted work after shutdown.')
191         start = time.time()
192         self.adjust_task_count(+1)
193         pickle = make_cloud_pickle(function, *args, **kwargs)
194         result = self._process_executor.submit(self.run_cloud_pickle, pickle)
195         result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
196         result.add_done_callback(lambda _: self.adjust_task_count(-1))
197         return result
198
199     @overrides
200     def shutdown(self, wait=True) -> None:
201         if not self.already_shutdown:
202             logger.debug(f'Shutting down processpool executor {self.title}')
203             self._process_executor.shutdown(wait)
204             print(self.histogram)
205             self.already_shutdown = True
206
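    # The underlying process pool cannot be pickled; drop it from the
    # serialized state so instances of this class remain picklable.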
207     def __getstate__(self):
208         state = self.__dict__.copy()
209         state['_process_executor'] = None
210         return state
211
212
213 class RemoteExecutorException(Exception):
214     """Thrown when a bundle cannot be executed despite several retries."""
215
216     pass
217
218
219 @dataclass
220 class RemoteWorkerRecord:
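    """A record for a remote worker: the username/machine to ssh into, a
    relative weight used by selection policies, and how many concurrent
    tasks it can accept."""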
221     username: str
222     machine: str
223     weight: int
224     count: int
225
226     def __hash__(self):
227         return hash((self.username, self.machine))
228
229     def __repr__(self):
230         return f'{self.username}@{self.machine}'
231
232
233 @dataclass
234 class BundleDetails:
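    """All the metadata the RemoteExecutor tracks for one unit of work
    (a "bundle"), including any backup bundles scheduled for it."""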
235     pickled_code: bytes
236     uuid: str
237     fname: str
238     worker: Optional[RemoteWorkerRecord]
239     username: Optional[str]
240     machine: Optional[str]
241     hostname: str
242     code_file: str
243     result_file: str
244     pid: int
245     start_ts: float
246     end_ts: float
247     slower_than_local_p95: bool
248     slower_than_global_p95: bool
249     src_bundle: BundleDetails
250     is_cancelled: threading.Event
251     was_cancelled: bool
252     backup_bundles: Optional[List[BundleDetails]]
253     failure_count: int
254
255     def __repr__(self):
256         uuid = self.uuid
257         if uuid[-9:-2] == '_backup':
258             uuid = uuid[:-9]
259             suffix = f'{uuid[-6:]}_b{self.uuid[-1:]}'
260         else:
261             suffix = uuid[-6:]
262
263         colorz = [
264             fg('violet red'),
265             fg('red'),
266             fg('orange'),
267             fg('peach orange'),
268             fg('yellow'),
269             fg('marigold yellow'),
270             fg('green yellow'),
271             fg('tea green'),
272             fg('cornflower blue'),
273             fg('turquoise blue'),
274             fg('tropical blue'),
275             fg('lavender purple'),
276             fg('medium purple'),
277         ]
278         c = colorz[int(uuid[-2:], 16) % len(colorz)]
279         fname = self.fname if self.fname is not None else 'nofname'
280         machine = self.machine if self.machine is not None else 'nomachine'
281         return f'{c}{suffix}/{fname}/{machine}{reset()}'
282
283
284 class RemoteExecutorStatus:
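    """Bookkeeping for a RemoteExecutor: known workers, in-flight bundles,
    and per-bundle / per-worker timing data, guarded by a single lock."""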
285     def __init__(self, total_worker_count: int) -> None:
286         self.worker_count: int = total_worker_count
287         self.known_workers: Set[RemoteWorkerRecord] = set()
288         self.start_time: float = time.time()
289         self.start_per_bundle: Dict[str, Optional[float]] = defaultdict(float)
290         self.end_per_bundle: Dict[str, float] = defaultdict(float)
291         self.finished_bundle_timings_per_worker: Dict[
292             RemoteWorkerRecord, List[float]
293         ] = {}
294         self.in_flight_bundles_by_worker: Dict[RemoteWorkerRecord, Set[str]] = {}
295         self.bundle_details_by_uuid: Dict[str, BundleDetails] = {}
296         self.finished_bundle_timings: List[float] = []
297         self.last_periodic_dump: Optional[float] = None
298         self.total_bundles_submitted: int = 0
299
300         # Protects reads and modifications of self.  Also used
301         # as a memory fence for modifications to bundles.
302         self.lock: threading.Lock = threading.Lock()
303
304     def record_acquire_worker(self, worker: RemoteWorkerRecord, uuid: str) -> None:
305         with self.lock:
306             self.record_acquire_worker_already_locked(worker, uuid)
307
308     def record_acquire_worker_already_locked(
309         self, worker: RemoteWorkerRecord, uuid: str
310     ) -> None:
311         assert self.lock.locked()
312         self.known_workers.add(worker)
313         self.start_per_bundle[uuid] = None
314         x = self.in_flight_bundles_by_worker.get(worker, set())
315         x.add(uuid)
316         self.in_flight_bundles_by_worker[worker] = x
317
318     def record_bundle_details(self, details: BundleDetails) -> None:
319         with self.lock:
320             self.record_bundle_details_already_locked(details)
321
322     def record_bundle_details_already_locked(self, details: BundleDetails) -> None:
323         assert self.lock.locked()
324         self.bundle_details_by_uuid[details.uuid] = details
325
326     def record_release_worker(
327         self,
328         worker: RemoteWorkerRecord,
329         uuid: str,
330         was_cancelled: bool,
331     ) -> None:
332         with self.lock:
333             self.record_release_worker_already_locked(worker, uuid, was_cancelled)
334
335     def record_release_worker_already_locked(
336         self,
337         worker: RemoteWorkerRecord,
338         uuid: str,
339         was_cancelled: bool,
340     ) -> None:
341         assert self.lock.locked()
342         ts = time.time()
343         self.end_per_bundle[uuid] = ts
344         self.in_flight_bundles_by_worker[worker].remove(uuid)
345         if not was_cancelled:
346             bundle_latency = ts - self.start_per_bundle[uuid]
347             x = self.finished_bundle_timings_per_worker.get(worker, list())
348             x.append(bundle_latency)
349             self.finished_bundle_timings_per_worker[worker] = x
350             self.finished_bundle_timings.append(bundle_latency)
351
352     def record_processing_began(self, uuid: str):
353         with self.lock:
354             self.start_per_bundle[uuid] = time.time()
355
356     def total_in_flight(self) -> int:
357         assert self.lock.locked()
358         total_in_flight = 0
359         for worker in self.known_workers:
360             total_in_flight += len(self.in_flight_bundles_by_worker[worker])
361         return total_in_flight
362
363     def total_idle(self) -> int:
364         assert self.lock.locked()
365         return self.worker_count - self.total_in_flight()
366
367     def __repr__(self):
368         assert self.lock.locked()
369         ts = time.time()
370         total_finished = len(self.finished_bundle_timings)
371         total_in_flight = self.total_in_flight()
372         ret = f'\n\n{underline()}Remote Executor Pool Status{reset()}: '
373         qall = None
374         if len(self.finished_bundle_timings) > 1:
375             qall = numpy.quantile(self.finished_bundle_timings, [0.5, 0.95])
376             ret += (
377                 f'⏱=∀p50:{qall[0]:.1f}s, ∀p95:{qall[1]:.1f}s, total={ts-self.start_time:.1f}s, '
378                 f'✅={total_finished}/{self.total_bundles_submitted}, '
379                 f'💻n={total_in_flight}/{self.worker_count}\n'
380             )
381         else:
382             ret += (
383                 f'⏱={ts-self.start_time:.1f}s, '
384                 f'✅={total_finished}/{self.total_bundles_submitted}, '
385                 f'💻n={total_in_flight}/{self.worker_count}\n'
386             )
387
388         for worker in self.known_workers:
389             ret += f'  {fg("lightning yellow")}{worker.machine}{reset()}: '
390             timings = self.finished_bundle_timings_per_worker.get(worker, [])
391             count = len(timings)
392             qworker = None
393             if count > 1:
394                 qworker = numpy.quantile(timings, [0.5, 0.95])
395                 ret += f' 💻p50: {qworker[0]:.1f}s, 💻p95: {qworker[1]:.1f}s\n'
396             else:
397                 ret += '\n'
398             if count > 0:
399                 ret += f'    ...finished {count} total bundle(s) so far\n'
400             in_flight = len(self.in_flight_bundles_by_worker[worker])
401             if in_flight > 0:
402                 ret += f'    ...{in_flight} bundles currently in flight:\n'
403                 for bundle_uuid in self.in_flight_bundles_by_worker[worker]:
404                     details = self.bundle_details_by_uuid.get(bundle_uuid, None)
405                     pid = str(details.pid) if (details and details.pid != 0) else "TBD"
406                     if self.start_per_bundle[bundle_uuid] is not None:
407                         sec = ts - self.start_per_bundle[bundle_uuid]
408                         ret += f'       (pid={pid}): {details} for {sec:.1f}s so far '
409                     else:
410                         ret += f'       {details} setting up / copying data...'
411                         sec = 0.0
412
413                     if qworker is not None:
414                         if sec > qworker[1]:
415                             ret += f'{bg("red")}>💻p95{reset()} '
416                             if details is not None:
417                                 details.slower_than_local_p95 = True
418                         else:
419                             if details is not None:
420                                 details.slower_than_local_p95 = False
421
422                     if qall is not None:
423                         if sec > qall[1]:
424                             ret += f'{bg("red")}>∀p95{reset()} '
425                             if details is not None:
426                                 details.slower_than_global_p95 = True
427                         elif details is not None:
428                             details.slower_than_global_p95 = False
429                     ret += '\n'
430         return ret
431
432     def periodic_dump(self, total_bundles_submitted: int) -> None:
433         assert self.lock.locked()
434         self.total_bundles_submitted = total_bundles_submitted
435         ts = time.time()
436         if self.last_periodic_dump is None or ts - self.last_periodic_dump > 5.0:
437             print(self)
438             self.last_periodic_dump = ts
439
440
441 class RemoteWorkerSelectionPolicy(ABC):
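    """Interface for policies that decide which remote worker should get
    the next bundle."""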
442     def register_worker_pool(self, workers):
443         self.workers = workers
444
445     @abstractmethod
446     def is_worker_available(self) -> bool:
447         pass
448
449     @abstractmethod
450     def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
451         pass
452
453
454 class WeightedRandomRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
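    """Picks a worker at random, weighted by each worker's weight and its
    remaining capacity (count)."""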
455     @overrides
456     def is_worker_available(self) -> bool:
457         for worker in self.workers:
458             if worker.count > 0:
459                 return True
460         return False
461
462     @overrides
463     def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
464         grabbag = []
465         for worker in self.workers:
466             if worker.machine != machine_to_avoid:
467                 if worker.count > 0:
468                     for _ in range(worker.count * worker.weight):
469                         grabbag.append(worker)
470
471         if len(grabbag) == 0:
472             logger.debug(
473                 f'There are no available workers that avoid {machine_to_avoid}...'
474             )
475             for worker in self.workers:
476                 if worker.count > 0:
477                     for _ in range(worker.count * worker.weight):
478                         grabbag.append(worker)
479
480         if len(grabbag) == 0:
481             logger.warning('There are no available workers?!')
482             return None
483
484         worker = random.sample(grabbag, 1)[0]
485         assert worker.count > 0
486         worker.count -= 1
487         logger.debug(f'Chose worker {worker}')
488         return worker
489
490
491 class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
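    """Cycles through the worker pool, handing work to the next worker
    that still has capacity."""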
492     def __init__(self) -> None:
493         self.index = 0
494
495     @overrides
496     def is_worker_available(self) -> bool:
497         for worker in self.workers:
498             if worker.count > 0:
499                 return True
500         return False
501
502     @overrides
503     def acquire_worker(
504         self, machine_to_avoid: Optional[str] = None
505     ) -> Optional[RemoteWorkerRecord]:
506         x = self.index
507         while True:
508             worker = self.workers[x]
509             if worker.count > 0:
510                 worker.count -= 1
511                 x += 1
512                 if x >= len(self.workers):
513                     x = 0
514                 self.index = x
515                 logger.debug(f'Selected worker {worker}')
516                 return worker
517             x += 1
518             if x >= len(self.workers):
519                 x = 0
520             if x == self.index:
521                 msg = 'Unexpectedly could not find an available worker; returning None.'
522                 logger.warning(msg)
523                 return None
524
525
526 class RemoteExecutor(BaseExecutor):
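    """An executor that copies cloudpickled bundles to remote machines via
    scp, runs them there over ssh, and copies the results back.  It can
    also schedule duplicate "backup" bundles for work that is running
    unexpectedly slowly."""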
527     def __init__(
528         self,
529         workers: List[RemoteWorkerRecord],
530         policy: RemoteWorkerSelectionPolicy,
531     ) -> None:
532         super().__init__()
533         self.workers = workers
534         self.policy = policy
535         self.worker_count = 0
536         for worker in self.workers:
537             self.worker_count += worker.count
538         if self.worker_count <= 0:
539             msg = f"We need somewhere to schedule work; count was {self.worker_count}"
540             logger.critical(msg)
541             raise RemoteExecutorException(msg)
542         self.policy.register_worker_pool(self.workers)
543         self.cv = threading.Condition()
544         logger.debug(
545             f'Creating {self.worker_count} local threads, one per remote worker.'
546         )
547         self._helper_executor = fut.ThreadPoolExecutor(
548             thread_name_prefix="remote_executor_helper",
549             max_workers=self.worker_count,
550         )
551         self.status = RemoteExecutorStatus(self.worker_count)
552         self.total_bundles_submitted = 0
553         self.backup_lock = threading.Lock()
554         self.last_backup = None
555         (
556             self.heartbeat_thread,
557             self.heartbeat_stop_event,
558         ) = self.run_periodic_heartbeat()
559         self.already_shutdown = False
560
561     @background_thread
562     def run_periodic_heartbeat(self, stop_event: threading.Event) -> None:
563         while not stop_event.is_set():
564             time.sleep(5.0)
565             logger.debug('Running periodic heartbeat code...')
566             self.heartbeat()
567         logger.debug('Periodic heartbeat thread shutting down.')
568
569     def heartbeat(self) -> None:
570         # Note: this is invoked on a background thread, not an
571         # executor thread.  Be careful what you do with it b/c it
572         # needs to get back and dump status again periodically.
573         with self.status.lock:
574             self.status.periodic_dump(self.total_bundles_submitted)
575
576             # Look for bundles to reschedule via executor.submit
577             if config.config['executors_schedule_remote_backups']:
578                 self.maybe_schedule_backup_bundles()
579
580     def maybe_schedule_backup_bundles(self):
581         assert self.status.lock.locked()
582         num_done = len(self.status.finished_bundle_timings)
583         num_idle_workers = self.worker_count - self.task_count
584         now = time.time()
585         if (
586             num_done > 2
587             and num_idle_workers > 1
588             and (self.last_backup is None or (now - self.last_backup > 9.0))
589             and self.backup_lock.acquire(blocking=False)
590         ):
591             try:
592                 assert self.backup_lock.locked()
593
594                 bundle_to_backup = None
595                 best_score = None
596                 for (
597                     worker,
598                     bundle_uuids,
599                 ) in self.status.in_flight_bundles_by_worker.items():
600
601                     # Prefer to schedule backups of bundles running on
602                     # slower machines.
603                     base_score = 0
604                     for record in self.workers:
605                         if worker.machine == record.machine:
606                             base_score = float(record.weight)
607                             base_score = 1.0 / base_score
608                             base_score *= 200.0
609                             base_score = int(base_score)
610                             break
611
612                     for uuid in bundle_uuids:
613                         bundle = self.status.bundle_details_by_uuid.get(uuid, None)
614                         if (
615                             bundle is not None
616                             and bundle.src_bundle is None
617                             and bundle.backup_bundles is not None
618                         ):
619                             score = base_score
620
621                             # Schedule backups of bundles running
622                             # longer; especially those that are
623                             # unexpectedly slow.
624                             start_ts = self.status.start_per_bundle[uuid]
625                             if start_ts is not None:
626                                 runtime = now - start_ts
627                                 score += runtime
628                                 logger.debug(
629                                     f'score[{bundle}] => {score}  # latency boost'
630                                 )
631
632                                 if bundle.slower_than_local_p95:
633                                     score += runtime / 2
634                                     logger.debug(
635                                         f'score[{bundle}] => {score}  # >worker p95'
636                                     )
637
638                                 if bundle.slower_than_global_p95:
639                                     score += runtime / 4
640                                     logger.debug(
641                                         f'score[{bundle}] => {score}  # >global p95'
642                                     )
643
644                             # Prefer backups of bundles that don't
645                             # have backups already.
646                             backup_count = len(bundle.backup_bundles)
647                             if backup_count == 0:
648                                 score *= 2
649                             elif backup_count == 1:
650                                 score /= 2
651                             elif backup_count == 2:
652                                 score /= 8
653                             else:
654                                 score = 0
655                             logger.debug(
656                                 f'score[{bundle}] => {score}  # {backup_count} dup backup factor'
657                             )
658
659                             if score != 0 and (
660                                 best_score is None or score > best_score
661                             ):
662                                 bundle_to_backup = bundle
663                                 assert bundle is not None
664                                 assert bundle.backup_bundles is not None
665                                 assert bundle.src_bundle is None
666                                 best_score = score
667
668                 # Note: this is all still happening on the heartbeat
669                 # runner thread.  That's ok because
670                 # schedule_backup_for_bundle uses the executor to
671                 # submit the bundle again which will cause it to be
672                 # picked up by a worker thread and allow this thread
673                 # to return to run future heartbeats.
674                 if bundle_to_backup is not None:
675                     self.last_backup = now
676                     logger.info(
677                         f'=====> SCHEDULING BACKUP {bundle_to_backup} (score={best_score:.1f}) <====='
678                     )
679                     self.schedule_backup_for_bundle(bundle_to_backup)
680             finally:
681                 self.backup_lock.release()
682
683     def is_worker_available(self) -> bool:
684         return self.policy.is_worker_available()
685
686     def acquire_worker(
687         self, machine_to_avoid: Optional[str] = None
688     ) -> Optional[RemoteWorkerRecord]:
689         return self.policy.acquire_worker(machine_to_avoid)
690
691     def find_available_worker_or_block(
692         self, machine_to_avoid: Optional[str] = None
693     ) -> RemoteWorkerRecord:
694         with self.cv:
695             while not self.is_worker_available():
696                 self.cv.wait()
697             worker = self.acquire_worker(machine_to_avoid)
698             if worker is not None:
699                 return worker
700         msg = "We should never reach this point in the code"
701         logger.critical(msg)
702         raise Exception(msg)
703
704     def release_worker(self, bundle: BundleDetails, *, was_cancelled=True) -> None:
705         worker = bundle.worker
706         assert worker is not None
707         logger.debug(f'Released worker {worker}')
708         self.status.record_release_worker(
709             worker,
710             bundle.uuid,
711             was_cancelled,
712         )
713         with self.cv:
714             worker.count += 1
715             self.cv.notify()
716         self.adjust_task_count(-1)
717
718     def check_if_cancelled(self, bundle: BundleDetails) -> bool:
719         with self.status.lock:
720             if bundle.is_cancelled.wait(timeout=0.0):
721                 logger.debug(f'Bundle {bundle.uuid} is cancelled, bail out.')
722                 bundle.was_cancelled = True
723                 return True
724         return False
725
726     def launch(self, bundle: BundleDetails, override_avoid_machine=None) -> Any:
727         """Find a worker for bundle or block until one is available."""
728         self.adjust_task_count(+1)
729         uuid = bundle.uuid
730         hostname = bundle.hostname
731         avoid_machine = override_avoid_machine
732         is_original = bundle.src_bundle is None
733
734         # Try not to schedule a backup on the same host as the original.
735         if avoid_machine is None and bundle.src_bundle is not None:
736             avoid_machine = bundle.src_bundle.machine
737         worker = None
738         while worker is None:
739             worker = self.find_available_worker_or_block(avoid_machine)
740         assert worker
741
742         # Ok, found a worker.
743         bundle.worker = worker
744         machine = bundle.machine = worker.machine
745         username = bundle.username = worker.username
746         self.status.record_acquire_worker(worker, uuid)
747         logger.debug(f'{bundle}: Running bundle on {worker}...')
748
749         # Before we do any work, make sure the bundle is still viable.
750         # It may have been some time between when it was submitted and
751         # now due to lack of worker availability and someone else may
752         # have already finished it.
753         if self.check_if_cancelled(bundle):
754             try:
755                 return self.process_work_result(bundle)
756             except Exception as e:
757                 logger.warning(
758                     f'{bundle}: bundle says it\'s cancelled upfront but no results?!'
759                 )
760                 self.release_worker(bundle)
761                 if is_original:
762                     # Weird.  We are the original owner of this
763                     # bundle.  For it to have been cancelled, a backup
764                     # must have already started and completed before
765                     # we even got started.  Moreover, the backup says
766                     # it is done but we can't find the results it
767                     # should have copied over.  Reschedule the whole
768                     # thing.
769                     logger.exception(e)
770                     logger.error(
771                         f'{bundle}: We are the original owner thread and yet there are '
772                         + 'no results for this bundle.  This is unexpected and bad.'
773                     )
774                     return self.emergency_retry_nasty_bundle(bundle)
775                 else:
776                     # Expected(?).  We're a backup and our bundle is
777                     # cancelled before we even got started.  Something
778                     # went bad in process_work_result (I actually don't
779                     # see what?) but probably not worth worrying
780                     # about.  Let the original thread worry about
781                     # either finding the results or complaining about
782                     # it.
783                     return None
784
785         # Send input code / data to worker machine if it's not local.
786         if hostname not in machine:
787             try:
788                 cmd = (
789                     f'{SCP} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
790                 )
791                 start_ts = time.time()
792                 logger.info(f"{bundle}: Copying work to {worker} via {cmd}.")
793                 run_silently(cmd)
794                 xfer_latency = time.time() - start_ts
795                 logger.debug(f"{bundle}: Copying to {worker} took {xfer_latency:.1f}s.")
796             except Exception as e:
797                 self.release_worker(bundle)
798                 if is_original:
799                     # Weird.  We tried to copy the code to the worker and it failed...
800                     # And we're the original bundle.  We have to retry.
801                     logger.exception(e)
802                     logger.error(
803                         f"{bundle}: Failed to send instructions to the worker machine?! "
804                         + "This is not expected; we\'re the original bundle so this shouldn\'t "
805                         + "be a race condition.  Attempting an emergency retry..."
806                     )
807                     return self.emergency_retry_nasty_bundle(bundle)
808                 else:
809                     # This is actually expected; we're a backup.
810                     # There's a race condition where someone else
811                     # already finished the work and removed the source
812                     # code file before we could copy it.  No biggie.
813                     msg = f'{bundle}: Failed to send instructions to the worker machine... '
814                     msg += 'We\'re a backup and this may be caused by the original (or some '
815                     msg += 'other backup) already finishing this work.  Ignoring this.'
816                     logger.warning(msg)
817                     return None
818
819         # Kick off the work.  Note that if this fails we let
820         # wait_for_process deal with it.
821         self.status.record_processing_began(uuid)
822         cmd = (
823             f'{SSH} {bundle.username}@{bundle.machine} '
824             f'"source py38-venv/bin/activate &&'
825             f' /home/scott/lib/python_modules/remote_worker.py'
826             f' --code_file {bundle.code_file} --result_file {bundle.result_file}"'
827         )
828         logger.debug(f'{bundle}: Executing {cmd} in the background to kick off work...')
829         p = cmd_in_background(cmd, silent=True)
830         bundle.pid = p.pid
831         logger.debug(
832             f'{bundle}: Local ssh process pid={p.pid}; remote worker is {machine}.'
833         )
834         return self.wait_for_process(p, bundle, 0)
835
836     def wait_for_process(
837         self, p: subprocess.Popen, bundle: BundleDetails, depth: int
838     ) -> Any:
839         machine = bundle.machine
840         pid = p.pid
841         if depth > 3:
842             logger.error(
843                 f"I've gotten repeated errors waiting on this bundle; giving up on pid={pid}."
844             )
845             p.terminate()
846             self.release_worker(bundle)
847             return self.emergency_retry_nasty_bundle(bundle)
848
849         # Spin until either the ssh job we scheduled finishes the
850         # bundle or some backup worker signals that they finished it
851         # before we could.
852         while True:
853             try:
854                 p.wait(timeout=0.25)
855             except subprocess.TimeoutExpired:
856                 if self.check_if_cancelled(bundle):
857                     logger.info(
858                         f'{bundle}: looks like another worker finished bundle...'
859                     )
860                     break
861             else:
862                 logger.info(f"{bundle}: pid {pid} ({machine}) is finished!")
863                 p = None
864                 break
865
866         # If we get here we believe the bundle is done; either the ssh
867         # subprocess finished (hopefully successfully) or we noticed
868         # that some other worker seems to have completed the bundle
869         # and we're bailing out.
870         try:
871             ret = self.process_work_result(bundle)
872             if ret is not None and p is not None:
873                 p.terminate()
874             return ret
875
876         # Something went wrong; e.g. we could not copy the results
877         # back, cleanup after ourselves on the remote machine, or
878         # unpickle the results we got from the remote machine.  If we
879         # still have an active ssh subprocess, keep waiting on it.
880         # Otherwise, time for an emergency reschedule.
881         except Exception as e:
882             logger.exception(e)
883             logger.error(f'{bundle}: Something unexpected just happened...')
884             if p is not None:
885                 msg = f"{bundle}: Failed to wrap up \"done\" bundle, re-waiting on active ssh."
886                 logger.warning(msg)
887                 return self.wait_for_process(p, bundle, depth + 1)
888             else:
889                 self.release_worker(bundle)
890                 return self.emergency_retry_nasty_bundle(bundle)
891
892     def process_work_result(self, bundle: BundleDetails) -> Any:
893         with self.status.lock:
894             is_original = bundle.src_bundle is None
895             was_cancelled = bundle.was_cancelled
896             username = bundle.username
897             machine = bundle.machine
898             result_file = bundle.result_file
899             code_file = bundle.code_file
900
901             # Whether original or backup, if we finished first we must
902             # fetch the results if the computation happened on a
903             # remote machine.
904             bundle.end_ts = time.time()
905             if not was_cancelled:
906                 assert bundle.machine is not None
907                 if bundle.hostname not in bundle.machine:
908                     cmd = f'{SCP} {username}@{machine}:{result_file} {result_file} 2>/dev/null'
909                     logger.info(
910                         f"{bundle}: Fetching results back from {username}@{machine} via {cmd}"
911                     )
912
913                     # If either of these throw they are handled in
914                     # wait_for_process.
915                     attempts = 0
916                     while True:
917                         try:
918                             run_silently(cmd)
919                         except Exception as e:
920                             attempts += 1
921                             if attempts >= 3:
922                                 raise e
923                         else:
924                             break
925
926                     run_silently(
927                         f'{SSH} {username}@{machine}'
928                         f' "/bin/rm -f {code_file} {result_file}"'
929                     )
930                     logger.debug(
931                         f'Fetching results back took {time.time() - bundle.end_ts:.1f}s.'
932                     )
933                 dur = bundle.end_ts - bundle.start_ts
934                 self.histogram.add_item(dur)
935
936         # Only the original worker should unpickle the file contents
937         # though since it's the only one whose result matters.  The
938         # original is also the only job that may delete result_file
939         # from disk.  Note that the original may have been cancelled
940         # if one of the backups finished first; it still must read the
941         # result from disk.
942         if is_original:
943             logger.debug(f"{bundle}: Unpickling {result_file}.")
944             try:
945                 with open(result_file, 'rb') as rb:
946                     serialized = rb.read()
947                 result = cloudpickle.loads(serialized)
948             except Exception as e:
949                 logger.exception(e)
950                 msg = f'Failed to load {result_file}... this is bad news.'
951                 logger.critical(msg)
952                 self.release_worker(bundle)
953
954                 # Re-raise the exception; the code in wait_for_process may
955                 # decide to emergency_retry_nasty_bundle here.
956                 raise Exception(e)
957             logger.debug(f'Removing local (master) {code_file} and {result_file}.')
958             os.remove(f'{result_file}')
959             os.remove(f'{code_file}')
960
961             # Notify any backups that the original is done so they
962             # should stop ASAP.  Do this whether or not we
963             # finished first since there could be more than one
964             # backup.
965             if bundle.backup_bundles is not None:
966                 for backup in bundle.backup_bundles:
967                     logger.debug(
968                         f'{bundle}: Notifying backup {backup.uuid} that it\'s cancelled'
969                     )
970                     backup.is_cancelled.set()
971
972         # This is a backup job and, by now, we have already fetched
973         # the bundle results.
974         else:
975             # Backup results don't matter, they just need to leave the
976             # result file in the right place for their originals to
977             # read/unpickle later.
978             result = None
979
980             # Tell the original to stop if we finished first.
981             if not was_cancelled:
982                 logger.debug(
983                     f'{bundle}: Notifying original {bundle.src_bundle.uuid} we beat them to it.'
984                 )
985                 bundle.src_bundle.is_cancelled.set()
986         self.release_worker(bundle, was_cancelled=was_cancelled)
987         return result
988
989     def create_original_bundle(self, pickle, fname: str):
990         from string_utils import generate_uuid
991
992         uuid = generate_uuid(omit_dashes=True)
993         code_file = f'/tmp/{uuid}.code.bin'
994         result_file = f'/tmp/{uuid}.result.bin'
995
996         logger.debug(f'Writing pickled code to {code_file}')
997         with open(f'{code_file}', 'wb') as wb:
998             wb.write(pickle)
999
1000         bundle = BundleDetails(
1001             pickled_code=pickle,
1002             uuid=uuid,
1003             fname=fname,
1004             worker=None,
1005             username=None,
1006             machine=None,
1007             hostname=platform.node(),
1008             code_file=code_file,
1009             result_file=result_file,
1010             pid=0,
1011             start_ts=time.time(),
1012             end_ts=0.0,
1013             slower_than_local_p95=False,
1014             slower_than_global_p95=False,
1015             src_bundle=None,
1016             is_cancelled=threading.Event(),
1017             was_cancelled=False,
1018             backup_bundles=[],
1019             failure_count=0,
1020         )
1021         self.status.record_bundle_details(bundle)
1022         logger.debug(f'{bundle}: Created an original bundle')
1023         return bundle
1024
1025     def create_backup_bundle(self, src_bundle: BundleDetails):
1026         assert src_bundle.backup_bundles is not None
1027         n = len(src_bundle.backup_bundles)
1028         uuid = src_bundle.uuid + f'_backup#{n}'
1029
1030         backup_bundle = BundleDetails(
1031             pickled_code=src_bundle.pickled_code,
1032             uuid=uuid,
1033             fname=src_bundle.fname,
1034             worker=None,
1035             username=None,
1036             machine=None,
1037             hostname=src_bundle.hostname,
1038             code_file=src_bundle.code_file,
1039             result_file=src_bundle.result_file,
1040             pid=0,
1041             start_ts=time.time(),
1042             end_ts=0.0,
1043             slower_than_local_p95=False,
1044             slower_than_global_p95=False,
1045             src_bundle=src_bundle,
1046             is_cancelled=threading.Event(),
1047             was_cancelled=False,
1048             backup_bundles=None,  # backup backups not allowed
1049             failure_count=0,
1050         )
1051         src_bundle.backup_bundles.append(backup_bundle)
1052         self.status.record_bundle_details_already_locked(backup_bundle)
1053         logger.debug(f'{backup_bundle}: Created a backup bundle')
1054         return backup_bundle
1055
1056     def schedule_backup_for_bundle(self, src_bundle: BundleDetails):
1057         assert self.status.lock.locked()
1058         assert src_bundle is not None
1059         backup_bundle = self.create_backup_bundle(src_bundle)
1060         logger.debug(
1061             f'{backup_bundle.uuid}/{backup_bundle.fname}: Scheduling backup for execution...'
1062         )
1063         self._helper_executor.submit(self.launch, backup_bundle)
1064
1065         # Results from backups don't matter; if they finish first
1066         # they will move the result_file to this machine and let
1067         # the original pick them up and unpickle them.
1068
1069     def emergency_retry_nasty_bundle(self, bundle: BundleDetails) -> Any:
1070         is_original = bundle.src_bundle is None
1071         bundle.worker = None
1072         avoid_last_machine = bundle.machine
1073         bundle.machine = None
1074         bundle.username = None
1075         bundle.failure_count += 1
1076         if is_original:
1077             retry_limit = 3
1078         else:
1079             retry_limit = 2
1080
1081         if bundle.failure_count > retry_limit:
1082             logger.error(
1083                 f'{bundle}: Tried this bundle too many times already ({retry_limit}x); giving up.'
1084             )
1085             if is_original:
1086                 raise RemoteExecutorException(
1087                     f'{bundle}: This bundle can\'t be completed despite several backups and retries'
1088                 )
1089             else:
1090                 logger.error(
1091                     f'{bundle}: At least it\'s only a backup; better luck with the others.'
1092                 )
1093             return None
1094         else:
1095             msg = f'>>> Emergency rescheduling {bundle} because of unexpected errors (wtf?!) <<<'
1096             logger.warning(msg)
1097             warnings.warn(msg)
1098             return self.launch(bundle, avoid_last_machine)
1099
1100     @overrides
1101     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
1102         if self.already_shutdown:
1103             raise Exception('Submitted work after shutdown.')
1104         pickle = make_cloud_pickle(function, *args, **kwargs)
1105         bundle = self.create_original_bundle(pickle, function.__name__)
1106         self.total_bundles_submitted += 1
1107         return self._helper_executor.submit(self.launch, bundle)
1108
1109     @overrides
1110     def shutdown(self, wait=True) -> None:
1111         if not self.already_shutdown:
1112             logger.debug(f'Shutting down RemoteExecutor {self.title}')
1113             self.heartbeat_stop_event.set()
1114             self.heartbeat_thread.join()
1115             self._helper_executor.shutdown(wait)
1116             print(self.histogram)
1117             self.already_shutdown = True
1118
1119
1120 @singleton
1121 class DefaultExecutors(object):
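    """A lazy singleton factory for the default thread, process, and remote
    executors.  The remote pool is built from whichever known helper
    machines respond to a ping."""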
1122     def __init__(self):
1123         self.thread_executor: Optional[ThreadExecutor] = None
1124         self.process_executor: Optional[ProcessExecutor] = None
1125         self.remote_executor: Optional[RemoteExecutor] = None
1126
1127     def ping(self, host) -> bool:
1128         logger.debug(f'RUN> ping -c 1 {host}')
1129         try:
1130             x = cmd_with_timeout(
1131                 f'ping -c 1 {host} >/dev/null 2>/dev/null', timeout_seconds=1.0
1132             )
1133             return x == 0
1134         except Exception:
1135             return False
1136
1137     def thread_pool(self) -> ThreadExecutor:
1138         if self.thread_executor is None:
1139             self.thread_executor = ThreadExecutor()
1140         return self.thread_executor
1141
1142     def process_pool(self) -> ProcessExecutor:
1143         if self.process_executor is None:
1144             self.process_executor = ProcessExecutor()
1145         return self.process_executor
1146
1147     def remote_pool(self) -> RemoteExecutor:
1148         if self.remote_executor is None:
1149             logger.info('Looking for some helper machines...')
1150             pool: List[RemoteWorkerRecord] = []
1151             if self.ping('cheetah.house'):
1152                 logger.info('Found cheetah.house')
1153                 pool.append(
1154                     RemoteWorkerRecord(
1155                         username='scott',
1156                         machine='cheetah.house',
1157                         weight=30,
1158                         count=6,
1159                     ),
1160                 )
1161             if self.ping('meerkat.cabin'):
1162                 logger.info('Found meerkat.cabin')
1163                 pool.append(
1164                     RemoteWorkerRecord(
1165                         username='scott',
1166                         machine='meerkat.cabin',
1167                         weight=12,
1168                         count=2,
1169                     ),
1170                 )
1171             if self.ping('wannabe.house'):
1172                 logger.info('Found wannabe.house')
1173                 pool.append(
1174                     RemoteWorkerRecord(
1175                         username='scott',
1176                         machine='wannabe.house',
1177                         weight=25,
1178                         count=10,
1179                     ),
1180                 )
1181             if self.ping('puma.cabin'):
1182                 logger.info('Found puma.cabin')
1183                 pool.append(
1184                     RemoteWorkerRecord(
1185                         username='scott',
1186                         machine='puma.cabin',
1187                         weight=30,
1188                         count=6,
1189                     ),
1190                 )
1191             if self.ping('backup.house'):
1192                 logger.info('Found backup.house')
1193                 pool.append(
1194                     RemoteWorkerRecord(
1195                         username='scott',
1196                         machine='backup.house',
1197                         weight=8,
1198                         count=2,
1199                     ),
1200                 )
1201
1202             # The controller machine has a lot to do; go easy on it.
1203             for record in pool:
1204                 if record.machine == platform.node() and record.count > 1:
1205                     logger.info(f'Reducing workload for {record.machine}.')
1206                     record.count = 1
1207
1208             policy = WeightedRandomRemoteWorkerSelectionPolicy()
1209             policy.register_worker_pool(pool)
1210             self.remote_executor = RemoteExecutor(pool, policy)
1211         return self.remote_executor
1212
1213     def shutdown(self) -> None:
1214         if self.thread_executor is not None:
1215             self.thread_executor.shutdown()
1216             self.thread_executor = None
1217         if self.process_executor is not None:
1218             self.process_executor.shutdown()
1219             self.process_executor = None
1220         if self.remote_executor is not None:
1221             self.remote_executor.shutdown()
1222             self.remote_executor = None