Adds a shutdown_if_idle method to the executors and also eliminates
[python_utils.git] / executors.py
1 #!/usr/bin/env python3
2
3 from __future__ import annotations
4
5 from abc import ABC, abstractmethod
6 import concurrent.futures as fut
7 from collections import defaultdict
8 from dataclasses import dataclass
9 import logging
10 import numpy
11 import os
12 import platform
13 import random
14 import subprocess
15 import threading
16 import time
17 from typing import Any, Callable, Dict, List, Optional, Set
18 import warnings
19
20 import cloudpickle  # type: ignore
21 from overrides import overrides
22
23 from ansi import bg, fg, underline, reset
24 import argparse_utils
25 import config
26 from decorator_utils import singleton
27 from exec_utils import run_silently, cmd_in_background, cmd_with_timeout
28 import histogram as hist
29 from thread_utils import background_thread
30
31
32 logger = logging.getLogger(__name__)
33
34 parser = config.add_commandline_args(
35     f"Executors ({__file__})", "Args related to processing executors."
36 )
37 parser.add_argument(
38     '--executors_threadpool_size',
39     type=int,
40     metavar='#THREADS',
41     help='Number of threads in the default threadpool, leave unset for default',
42     default=None,
43 )
44 parser.add_argument(
45     '--executors_processpool_size',
46     type=int,
47     metavar='#PROCESSES',
48     help='Number of processes in the default processpool, leave unset for default',
49     default=None,
50 )
51 parser.add_argument(
52     '--executors_schedule_remote_backups',
53     default=True,
54     action=argparse_utils.ActionNoYes,
55     help='Should we schedule duplicative backup work if a remote bundle is slow',
56 )
57 parser.add_argument(
58     '--executors_max_bundle_failures',
59     type=int,
60     default=3,
61     metavar='#FAILURES',
62     help='Maximum number of failures before giving up on a bundle',
63 )
64
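# Example invocation (a sketch; the program name is hypothetical, the flags
# are the ones registered above via config/argparse):
#
#     ./some_program_using_executors.py \
#         --executors_threadpool_size 16 \
#         --executors_processpool_size 4 \
#         --executors_max_bundle_failures 5
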
65 SSH = '/usr/bin/ssh -oForwardX11=no'
66 SCP = '/usr/bin/scp -C'
67
68
69 def make_cloud_pickle(fun, *args, **kwargs):
70     logger.debug(f"Making cloudpickled bundle at {fun.__name__}")
71     return cloudpickle.dumps((fun, args, kwargs))
72
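# Round-trip sketch of the bundle format produced by make_cloud_pickle above
# (standard cloudpickle API; max is just an illustrative callable):
#
#     blob = make_cloud_pickle(max, 1, 2)
#     fun, args, kwargs = cloudpickle.loads(blob)
#     assert fun(*args, **kwargs) == 2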
73
74 class BaseExecutor(ABC):
75     def __init__(self, *, title=''):
76         self.title = title
77         self.histogram = hist.SimpleHistogram(
78             hist.SimpleHistogram.n_evenly_spaced_buckets(int(0), int(500), 50)
79         )
80         self.task_count = 0
81
82     @abstractmethod
83     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
84         pass
85
86     @abstractmethod
87     def shutdown(self, wait: bool = True) -> None:
88         pass
89
90     def shutdown_if_idle(self) -> bool:
91         """If the executor is idle (i.e. there are no pending or active
92         tasks), shut it down and return True.  Otherwise, do nothing and
93         return False.  Note: this should only be called by the launcher
94         process.
95
96         """
97         if self.task_count == 0:
98             self.shutdown()
99             return True
100         return False
101
102     def adjust_task_count(self, delta: int) -> None:
103         """Change the task count.  Note: do not call this method from a
104         worker; it should only be called by the launcher process /
105         thread / machine.
106
107         """
108         self.task_count += delta
109         logger.debug(f'Adjusted task count by {delta} to {self.task_count}')
110
111     def get_task_count(self) -> int:
112         """Return the current task count.  Note: do not call this method
113         from a worker; it should only be called by the launcher process /
114         thread / machine.
115
116         """
117         return self.task_count
118
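# Minimal usage sketch of the BaseExecutor API (hypothetical work() callable;
# ThreadExecutor is defined just below):
#
#     executor = ThreadExecutor(max_workers=4)
#     future = executor.submit(work, 123)
#     print(future.result())
#     executor.shutdown_if_idle()  # True once no tasks remain outstanding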
119
120 class ThreadExecutor(BaseExecutor):
121     def __init__(self, max_workers: Optional[int] = None):
122         super().__init__()
123         workers = None
124         if max_workers is not None:
125             workers = max_workers
126         elif 'executors_threadpool_size' in config.config:
127             workers = config.config['executors_threadpool_size']
128         logger.debug(f'Creating threadpool executor with {workers} workers')
129         self._thread_pool_executor = fut.ThreadPoolExecutor(
130             max_workers=workers, thread_name_prefix="thread_executor_helper"
131         )
132         self.already_shutdown = False
133
134     # This is run on a different thread; do not adjust task count here.
135     def run_local_bundle(self, fun, *args, **kwargs):
136         logger.debug(f"Running local bundle at {fun.__name__}")
137         result = fun(*args, **kwargs)
138         return result
139
140     @overrides
141     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
142         if self.already_shutdown:
143             raise Exception('Submitted work after shutdown.')
144         self.adjust_task_count(+1)
145         newargs = []
146         newargs.append(function)
147         for arg in args:
148             newargs.append(arg)
149         start = time.time()
150         result = self._thread_pool_executor.submit(
151             self.run_local_bundle, *newargs, **kwargs
152         )
153         result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
154         result.add_done_callback(lambda _: self.adjust_task_count(-1))
        return result
155
156     @overrides
157     def shutdown(self, wait=True) -> None:
158         if not self.already_shutdown:
159             logger.debug(f'Shutting down threadpool executor {self.title}')
160             print(self.histogram)
161             self._thread_pool_executor.shutdown(wait)
162             self.already_shutdown = True
163
164
165 class ProcessExecutor(BaseExecutor):
166     def __init__(self, max_workers=None):
167         super().__init__()
168         workers = None
169         if max_workers is not None:
170             workers = max_workers
171         elif 'executors_processpool_size' in config.config:
172             workers = config.config['executors_processpool_size']
173         logger.debug(f'Creating processpool executor with {workers} workers.')
174         self._process_executor = fut.ProcessPoolExecutor(
175             max_workers=workers,
176         )
177         self.already_shutdown = False
178
179     # This is run in another process; do not adjust task count here.
180     def run_cloud_pickle(self, pickle):
181         fun, args, kwargs = cloudpickle.loads(pickle)
182         logger.debug(f"Running pickled bundle at {fun.__name__}")
183         result = fun(*args, **kwargs)
184         return result
185
186     @overrides
187     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
188         if self.already_shutdown:
189             raise Exception('Submitted work after shutdown.')
190         start = time.time()
191         self.adjust_task_count(+1)
192         pickle = make_cloud_pickle(function, *args, **kwargs)
193         result = self._process_executor.submit(self.run_cloud_pickle, pickle)
194         result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
195         result.add_done_callback(lambda _: self.adjust_task_count(-1))
196         return result
197
198     @overrides
199     def shutdown(self, wait=True) -> None:
200         if not self.already_shutdown:
201             logger.debug(f'Shutting down processpool executor {self.title}')
202             self._process_executor.shutdown(wait)
203             print(self.histogram)
204             self.already_shutdown = True
205
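    # __getstate__ below drops the unpicklable process pool handle, presumably
    # so that this object can itself be pickled (e.g. if it gets captured by
    # cloudpickle alongside submitted work).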
206     def __getstate__(self):
207         state = self.__dict__.copy()
208         state['_process_executor'] = None
209         return state
210
211
212 class RemoteExecutorException(Exception):
213     """Thrown when a bundle cannot be executed despite several retries."""
214
215     pass
216
217
218 @dataclass
219 class RemoteWorkerRecord:
220     username: str
221     machine: str
222     weight: int
223     count: int
224
225     def __hash__(self):
226         return hash((self.username, self.machine))
227
228     def __repr__(self):
229         return f'{self.username}@{self.machine}'
230
231
232 @dataclass
233 class BundleDetails:
234     pickled_code: bytes
235     uuid: str
236     fname: str
237     worker: Optional[RemoteWorkerRecord]
238     username: Optional[str]
239     machine: Optional[str]
240     hostname: str
241     code_file: str
242     result_file: str
243     pid: int
244     start_ts: float
245     end_ts: float
246     slower_than_local_p95: bool
247     slower_than_global_p95: bool
248     src_bundle: Optional[BundleDetails]
249     is_cancelled: threading.Event
250     was_cancelled: bool
251     backup_bundles: Optional[List[BundleDetails]]
252     failure_count: int
253
254     def __repr__(self):
255         uuid = self.uuid
256         if uuid[-9:-2] == '_backup':
257             uuid = uuid[:-9]
258             suffix = f'{uuid[-6:]}_b{self.uuid[-1:]}'
259         else:
260             suffix = uuid[-6:]
261
262         colorz = [
263             fg('violet red'),
264             fg('red'),
265             fg('orange'),
266             fg('peach orange'),
267             fg('yellow'),
268             fg('marigold yellow'),
269             fg('green yellow'),
270             fg('tea green'),
271             fg('cornflower blue'),
272             fg('turquoise blue'),
273             fg('tropical blue'),
274             fg('lavender purple'),
275             fg('medium purple'),
276         ]
277         c = colorz[int(uuid[-2:], 16) % len(colorz)]
278         fname = self.fname if self.fname is not None else 'nofname'
279         machine = self.machine if self.machine is not None else 'nomachine'
280         return f'{c}{suffix}/{fname}/{machine}{reset()}'
281
282
283 class RemoteExecutorStatus:
284     def __init__(self, total_worker_count: int) -> None:
285         self.worker_count: int = total_worker_count
286         self.known_workers: Set[RemoteWorkerRecord] = set()
287         self.start_time: float = time.time()
288         self.start_per_bundle: Dict[str, Optional[float]] = defaultdict(float)
289         self.end_per_bundle: Dict[str, float] = defaultdict(float)
290         self.finished_bundle_timings_per_worker: Dict[
291             RemoteWorkerRecord, List[float]
292         ] = {}
293         self.in_flight_bundles_by_worker: Dict[RemoteWorkerRecord, Set[str]] = {}
294         self.bundle_details_by_uuid: Dict[str, BundleDetails] = {}
295         self.finished_bundle_timings: List[float] = []
296         self.last_periodic_dump: Optional[float] = None
297         self.total_bundles_submitted: int = 0
298
299         # Protects reads and modifications of self.  Also used
300         # as a memory fence for modifications to bundles.
301         self.lock: threading.Lock = threading.Lock()
302
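    # Convention: the *_already_locked methods assume the caller already
    # holds self.lock (asserted below); the unsuffixed variants acquire the
    # lock themselves.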
303     def record_acquire_worker(self, worker: RemoteWorkerRecord, uuid: str) -> None:
304         with self.lock:
305             self.record_acquire_worker_already_locked(worker, uuid)
306
307     def record_acquire_worker_already_locked(
308         self, worker: RemoteWorkerRecord, uuid: str
309     ) -> None:
310         assert self.lock.locked()
311         self.known_workers.add(worker)
312         self.start_per_bundle[uuid] = None
313         x = self.in_flight_bundles_by_worker.get(worker, set())
314         x.add(uuid)
315         self.in_flight_bundles_by_worker[worker] = x
316
317     def record_bundle_details(self, details: BundleDetails) -> None:
318         with self.lock:
319             self.record_bundle_details_already_locked(details)
320
321     def record_bundle_details_already_locked(self, details: BundleDetails) -> None:
322         assert self.lock.locked()
323         self.bundle_details_by_uuid[details.uuid] = details
324
325     def record_release_worker(
326         self,
327         worker: RemoteWorkerRecord,
328         uuid: str,
329         was_cancelled: bool,
330     ) -> None:
331         with self.lock:
332             self.record_release_worker_already_locked(worker, uuid, was_cancelled)
333
334     def record_release_worker_already_locked(
335         self,
336         worker: RemoteWorkerRecord,
337         uuid: str,
338         was_cancelled: bool,
339     ) -> None:
340         assert self.lock.locked()
341         ts = time.time()
342         self.end_per_bundle[uuid] = ts
343         self.in_flight_bundles_by_worker[worker].remove(uuid)
344         if not was_cancelled:
345             bundle_latency = ts - self.start_per_bundle[uuid]
346             x = self.finished_bundle_timings_per_worker.get(worker, list())
347             x.append(bundle_latency)
348             self.finished_bundle_timings_per_worker[worker] = x
349             self.finished_bundle_timings.append(bundle_latency)
350
351     def record_processing_began(self, uuid: str):
352         with self.lock:
353             self.start_per_bundle[uuid] = time.time()
354
355     def total_in_flight(self) -> int:
356         assert self.lock.locked()
357         total_in_flight = 0
358         for worker in self.known_workers:
359             total_in_flight += len(self.in_flight_bundles_by_worker[worker])
360         return total_in_flight
361
362     def total_idle(self) -> int:
363         assert self.lock.locked()
364         return self.worker_count - self.total_in_flight()
365
366     def __repr__(self):
367         assert self.lock.locked()
368         ts = time.time()
369         total_finished = len(self.finished_bundle_timings)
370         total_in_flight = self.total_in_flight()
371         ret = f'\n\n{underline()}Remote Executor Pool Status{reset()}: '
372         qall = None
373         if len(self.finished_bundle_timings) > 1:
374             qall = numpy.quantile(self.finished_bundle_timings, [0.5, 0.95])
375             ret += (
376                 f'⏱=∀p50:{qall[0]:.1f}s, ∀p95:{qall[1]:.1f}s, total={ts-self.start_time:.1f}s, '
377                 f'✅={total_finished}/{self.total_bundles_submitted}, '
378                 f'💻n={total_in_flight}/{self.worker_count}\n'
379             )
380         else:
381             ret += (
382                 f'⏱={ts-self.start_time:.1f}s, '
383                 f'✅={total_finished}/{self.total_bundles_submitted}, '
384                 f'💻n={total_in_flight}/{self.worker_count}\n'
385             )
386
387         for worker in self.known_workers:
388             ret += f'  {fg("lightning yellow")}{worker.machine}{reset()}: '
389             timings = self.finished_bundle_timings_per_worker.get(worker, [])
390             count = len(timings)
391             qworker = None
392             if count > 1:
393                 qworker = numpy.quantile(timings, [0.5, 0.95])
394                 ret += f' 💻p50: {qworker[0]:.1f}s, 💻p95: {qworker[1]:.1f}s\n'
395             else:
396                 ret += '\n'
397             if count > 0:
398                 ret += f'    ...finished {count} total bundle(s) so far\n'
399             in_flight = len(self.in_flight_bundles_by_worker[worker])
400             if in_flight > 0:
401                 ret += f'    ...{in_flight} bundles currently in flight:\n'
402                 for bundle_uuid in self.in_flight_bundles_by_worker[worker]:
403                     details = self.bundle_details_by_uuid.get(bundle_uuid, None)
404                     pid = str(details.pid) if (details and details.pid != 0) else "TBD"
405                     if self.start_per_bundle[bundle_uuid] is not None:
406                         sec = ts - self.start_per_bundle[bundle_uuid]
407                         ret += f'       (pid={pid}): {details} for {sec:.1f}s so far '
408                     else:
409                         ret += f'       {details} setting up / copying data...'
410                         sec = 0.0
411
412                     if qworker is not None:
413                         if sec > qworker[1]:
414                             ret += f'{bg("red")}>💻p95{reset()} '
415                             if details is not None:
416                                 details.slower_than_local_p95 = True
417                         else:
418                             if details is not None:
419                                 details.slower_than_local_p95 = False
420
421                     if qall is not None:
422                         if sec > qall[1]:
423                             ret += f'{bg("red")}>∀p95{reset()} '
424                             if details is not None:
425                                 details.slower_than_global_p95 = True
426                         else:
427                             if details is not None:
                                details.slower_than_global_p95 = False
428                     ret += '\n'
429         return ret
430
431     def periodic_dump(self, total_bundles_submitted: int) -> None:
432         assert self.lock.locked()
433         self.total_bundles_submitted = total_bundles_submitted
434         ts = time.time()
435         if self.last_periodic_dump is None or ts - self.last_periodic_dump > 5.0:
436             print(self)
437             self.last_periodic_dump = ts
438
439
440 class RemoteWorkerSelectionPolicy(ABC):
441     def register_worker_pool(self, workers):
442         self.workers = workers
443
444     @abstractmethod
445     def is_worker_available(self) -> bool:
446         pass
447
448     @abstractmethod
449     def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
450         pass
451
452
453 class WeightedRandomRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
454     @overrides
455     def is_worker_available(self) -> bool:
456         for worker in self.workers:
457             if worker.count > 0:
458                 return True
459         return False
460
461     @overrides
462     def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
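        # Build a "grab bag" in which each worker appears count * weight
        # times; sampling uniformly from it favors workers with more free
        # slots and higher speed weights.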
463         grabbag = []
464         for worker in self.workers:
465             if worker.machine != machine_to_avoid:
466                 if worker.count > 0:
467                     for _ in range(worker.count * worker.weight):
468                         grabbag.append(worker)
469
470         if len(grabbag) == 0:
471             logger.debug(
472                 f'There are no available workers that avoid {machine_to_avoid}...'
473             )
474             for worker in self.workers:
475                 if worker.count > 0:
476                     for _ in range(worker.count * worker.weight):
477                         grabbag.append(worker)
478
479         if len(grabbag) == 0:
480             logger.warning('There are no available workers?!')
481             return None
482
483         worker = random.sample(grabbag, 1)[0]
484         assert worker.count > 0
485         worker.count -= 1
486         logger.debug(f'Chose worker {worker}')
487         return worker
488
489
490 class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
491     def __init__(self) -> None:
492         self.index = 0
493
494     @overrides
495     def is_worker_available(self) -> bool:
496         for worker in self.workers:
497             if worker.count > 0:
498                 return True
499         return False
500
501     @overrides
502     def acquire_worker(
503         self, machine_to_avoid: Optional[str] = None
504     ) -> Optional[RemoteWorkerRecord]:
505         x = self.index
506         while True:
507             worker = self.workers[x]
508             if worker.count > 0:
509                 worker.count -= 1
510                 x += 1
511                 if x >= len(self.workers):
512                     x = 0
513                 self.index = x
514                 logger.debug(f'Selected worker {worker}')
515                 return worker
516             x += 1
517             if x >= len(self.workers):
518                 x = 0
519             if x == self.index:
520                 msg = 'Unexpectedly could not find an available worker.'
521                 logger.warning(msg)
522                 return None
523
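# Wiring sketch for a remote pool (hypothetical callable; the host and
# username mirror DefaultExecutors.remote_pool below):
#
#     workers = [
#         RemoteWorkerRecord(username='scott', machine='cheetah.house', weight=30, count=6),
#     ]
#     policy = WeightedRandomRemoteWorkerSelectionPolicy()
#     policy.register_worker_pool(workers)
#     remote = RemoteExecutor(workers, policy)
#     future = remote.submit(some_function, some_arg)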
524
525 class RemoteExecutor(BaseExecutor):
526     def __init__(
527         self,
528         workers: List[RemoteWorkerRecord],
529         policy: RemoteWorkerSelectionPolicy,
530     ) -> None:
531         super().__init__()
532         self.workers = workers
533         self.policy = policy
534         self.worker_count = 0
535         for worker in self.workers:
536             self.worker_count += worker.count
537         if self.worker_count <= 0:
538             msg = f"We need somewhere to schedule work; count was {self.worker_count}"
539             logger.critical(msg)
540             raise RemoteExecutorException(msg)
541         self.policy.register_worker_pool(self.workers)
542         self.cv = threading.Condition()
543         logger.debug(
544             f'Creating {self.worker_count} local threads, one per remote worker.'
545         )
546         self._helper_executor = fut.ThreadPoolExecutor(
547             thread_name_prefix="remote_executor_helper",
548             max_workers=self.worker_count,
549         )
550         self.status = RemoteExecutorStatus(self.worker_count)
551         self.total_bundles_submitted = 0
552         self.backup_lock = threading.Lock()
553         self.last_backup = None
554         (
555             self.heartbeat_thread,
556             self.heartbeat_stop_event,
557         ) = self.run_periodic_heartbeat()
558         self.already_shutdown = False
559
560     @background_thread
561     def run_periodic_heartbeat(self, stop_event: threading.Event) -> None:
562         while not stop_event.is_set():
563             time.sleep(5.0)
564             logger.debug('Running periodic heartbeat code...')
565             self.heartbeat()
566         logger.debug('Periodic heartbeat thread shutting down.')
567
568     def heartbeat(self) -> None:
569         # Note: this is invoked on a background thread, not an
570         # executor thread.  Be careful what you do with it b/c it
571         # needs to get back and dump status again periodically.
572         with self.status.lock:
573             self.status.periodic_dump(self.total_bundles_submitted)
574
575             # Look for bundles to reschedule via executor.submit
576             if config.config['executors_schedule_remote_backups']:
577                 self.maybe_schedule_backup_bundles()
578
579     def maybe_schedule_backup_bundles(self):
580         assert self.status.lock.locked()
581         num_done = len(self.status.finished_bundle_timings)
582         num_idle_workers = self.worker_count - self.task_count
583         now = time.time()
584         if (
585             num_done > 2
586             and num_idle_workers > 1
587             and (self.last_backup is None or (now - self.last_backup > 9.0))
588             and self.backup_lock.acquire(blocking=False)
589         ):
590             try:
591                 assert self.backup_lock.locked()
592
593                 bundle_to_backup = None
594                 best_score = None
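                # Scoring sketch: start from ~200 / machine weight, add the
                # bundle's runtime so far (plus runtime/2 if it is already
                # slower than its worker's p95 and runtime/4 if slower than
                # the global p95), then double the score if the bundle has no
                # backups yet, halve it for one backup, divide by 8 for two,
                # and drop it entirely at three or more.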
595                 for (
596                     worker,
597                     bundle_uuids,
598                 ) in self.status.in_flight_bundles_by_worker.items():
599
600                     # Prefer to schedule backups of bundles running on
601                     # slower machines.
602                     base_score = 0
603                     for record in self.workers:
604                         if worker.machine == record.machine:
605                             base_score = float(record.weight)
606                             base_score = 1.0 / base_score
607                             base_score *= 200.0
608                             base_score = int(base_score)
609                             break
610
611                     for uuid in bundle_uuids:
612                         bundle = self.status.bundle_details_by_uuid.get(uuid, None)
613                         if (
614                             bundle is not None
615                             and bundle.src_bundle is None
616                             and bundle.backup_bundles is not None
617                         ):
618                             score = base_score
619
620                             # Schedule backups of bundles running
621                             # longer; especially those that are
622                             # unexpectedly slow.
623                             start_ts = self.status.start_per_bundle[uuid]
624                             if start_ts is not None:
625                                 runtime = now - start_ts
626                                 score += runtime
627                                 logger.debug(
628                                     f'score[{bundle}] => {score}  # latency boost'
629                                 )
630
631                                 if bundle.slower_than_local_p95:
632                                     score += runtime / 2
633                                     logger.debug(
634                                         f'score[{bundle}] => {score}  # >worker p95'
635                                     )
636
637                                 if bundle.slower_than_global_p95:
638                                     score += runtime / 4
639                                     logger.debug(
640                                         f'score[{bundle}] => {score}  # >global p95'
641                                     )
642
643                             # Prefer backups of bundles that don't
644                             # have backups already.
645                             backup_count = len(bundle.backup_bundles)
646                             if backup_count == 0:
647                                 score *= 2
648                             elif backup_count == 1:
649                                 score /= 2
650                             elif backup_count == 2:
651                                 score /= 8
652                             else:
653                                 score = 0
654                             logger.debug(
655                                 f'score[{bundle}] => {score}  # {backup_count} dup backup factor'
656                             )
657
658                             if score != 0 and (
659                                 best_score is None or score > best_score
660                             ):
661                                 bundle_to_backup = bundle
662                                 assert bundle is not None
663                                 assert bundle.backup_bundles is not None
664                                 assert bundle.src_bundle is None
665                                 best_score = score
666
667                 # Note: this is all still happening on the heartbeat
668                 # runner thread.  That's ok because
669                 # schedule_backup_for_bundle uses the executor to
670                 # submit the bundle again which will cause it to be
671                 # picked up by a worker thread and allow this thread
672                 # to return to run future heartbeats.
673                 if bundle_to_backup is not None:
674                     self.last_backup = now
675                     logger.info(
676                         f'=====> SCHEDULING BACKUP {bundle_to_backup} (score={best_score:.1f}) <====='
677                     )
678                     self.schedule_backup_for_bundle(bundle_to_backup)
679             finally:
680                 self.backup_lock.release()
681
682     def is_worker_available(self) -> bool:
683         return self.policy.is_worker_available()
684
685     def acquire_worker(
686         self, machine_to_avoid: Optional[str] = None
687     ) -> Optional[RemoteWorkerRecord]:
688         return self.policy.acquire_worker(machine_to_avoid)
689
690     def find_available_worker_or_block(
691         self, machine_to_avoid: Optional[str] = None
692     ) -> RemoteWorkerRecord:
693         with self.cv:
694             while not self.is_worker_available():
695                 self.cv.wait()
696             worker = self.acquire_worker(machine_to_avoid)
697             if worker is not None:
698                 return worker
699         msg = "We should never reach this point in the code"
700         logger.critical(msg)
701         raise Exception(msg)
702
703     def release_worker(self, bundle: BundleDetails, *, was_cancelled=True) -> None:
704         worker = bundle.worker
705         assert worker is not None
706         logger.debug(f'Released worker {worker}')
707         self.status.record_release_worker(
708             worker,
709             bundle.uuid,
710             was_cancelled,
711         )
712         with self.cv:
713             worker.count += 1
714             self.cv.notify()
715         self.adjust_task_count(-1)
716
717     def check_if_cancelled(self, bundle: BundleDetails) -> bool:
718         with self.status.lock:
719             if bundle.is_cancelled.wait(timeout=0.0):
720                 logger.debug(f'Bundle {bundle.uuid} is cancelled, bail out.')
721                 bundle.was_cancelled = True
722                 return True
723         return False
724
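    # Bundle life cycle, as implemented by launch / wait_for_process /
    # process_work_result below: acquire a worker, scp the pickled code to it
    # (skipped when the worker is this same host), ssh in to run
    # remote_worker.py, poll until the ssh child exits or a backup signals
    # completion, then scp the result file back and unpickle it (originals
    # only).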
725     def launch(self, bundle: BundleDetails, override_avoid_machine=None) -> Any:
726         """Find a worker for bundle or block until one is available."""
727         self.adjust_task_count(+1)
728         uuid = bundle.uuid
729         hostname = bundle.hostname
730         avoid_machine = override_avoid_machine
731         is_original = bundle.src_bundle is None
732
733         # Try not to schedule a backup on the same host as the original.
734         if avoid_machine is None and bundle.src_bundle is not None:
735             avoid_machine = bundle.src_bundle.machine
736         worker = None
737         while worker is None:
738             worker = self.find_available_worker_or_block(avoid_machine)
739         assert worker
740
741         # Ok, found a worker.
742         bundle.worker = worker
743         machine = bundle.machine = worker.machine
744         username = bundle.username = worker.username
745         self.status.record_acquire_worker(worker, uuid)
746         logger.debug(f'{bundle}: Running bundle on {worker}...')
747
748         # Before we do any work, make sure the bundle is still viable.
749         # It may have been some time between when it was submitted and
750         # now due to lack of worker availability and someone else may
751         # have already finished it.
752         if self.check_if_cancelled(bundle):
753             try:
754                 return self.process_work_result(bundle)
755             except Exception as e:
756                 logger.warning(
757                     f'{bundle}: bundle says it\'s cancelled upfront but no results?!'
758                 )
759                 self.release_worker(bundle)
760                 if is_original:
761                     # Weird.  We are the original owner of this
762                     # bundle.  For it to have been cancelled, a backup
763                     # must have already started and completed before
764                     # we even got started.  Moreover, the backup says
765                     # it is done but we can't find the results it
766                     # should have copied over.  Reschedule the whole
767                     # thing.
768                     logger.exception(e)
769                     logger.error(
770                         f'{bundle}: We are the original owner thread and yet there are '
771                         + 'no results for this bundle.  This is unexpected and bad.'
772                     )
773                     return self.emergency_retry_nasty_bundle(bundle)
774                 else:
775                     # Expected(?).  We're a backup and our bundle is
776                     # cancelled before we even got started.  Something
777                     # went bad in process_work_result (I actually don't
778                     # see what?) but probably not worth worrying
779                     # about.  Let the original thread worry about
780                     # either finding the results or complaining about
781                     # it.
782                     return None
783
784         # Send input code / data to worker machine if it's not local.
785         if hostname not in machine:
786             try:
787                 cmd = (
788                     f'{SCP} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
789                 )
790                 start_ts = time.time()
791                 logger.info(f"{bundle}: Copying work to {worker} via {cmd}.")
792                 run_silently(cmd)
793                 xfer_latency = time.time() - start_ts
794                 logger.debug(f"{bundle}: Copying to {worker} took {xfer_latency:.1f}s.")
795             except Exception as e:
796                 self.release_worker(bundle)
797                 if is_original:
798                     # Weird.  We tried to copy the code to the worker and it failed...
799                     # And we're the original bundle.  We have to retry.
800                     logger.exception(e)
801                     logger.error(
802                         f"{bundle}: Failed to send instructions to the worker machine?! "
803                         + "This is not expected; we\'re the original bundle so this shouldn\'t "
804                         + "be a race condition.  Attempting an emergency retry..."
805                     )
806                     return self.emergency_retry_nasty_bundle(bundle)
807                 else:
808                     # This is actually expected; we're a backup.
809                     # There's a race condition where someone else
810                     # already finished the work and removed the source
811                     # code file before we could copy it.  No biggie.
812                     msg = f'{bundle}: Failed to send instructions to the worker machine... '
813                     msg += 'We\'re a backup and this may be caused by the original (or some '
814                     msg += 'other backup) already finishing this work.  Ignoring this.'
815                     logger.warning(msg)
816                     return None
817
818         # Kick off the work.  Note that if this fails we let
819         # wait_for_process deal with it.
820         self.status.record_processing_began(uuid)
821         cmd = (
822             f'{SSH} {bundle.username}@{bundle.machine} '
823             f'"source py38-venv/bin/activate &&'
824             f' /home/scott/lib/python_modules/remote_worker.py'
825             f' --code_file {bundle.code_file} --result_file {bundle.result_file}"'
826         )
827         logger.debug(f'{bundle}: Executing {cmd} in the background to kick off work...')
828         p = cmd_in_background(cmd, silent=True)
829         bundle.pid = p.pid
830         logger.debug(
831             f'{bundle}: Local ssh process pid={p.pid}; remote worker is {machine}.'
832         )
833         return self.wait_for_process(p, bundle, 0)
834
835     def wait_for_process(
836         self, p: subprocess.Popen, bundle: BundleDetails, depth: int
837     ) -> Any:
838         machine = bundle.machine
839         pid = p.pid
840         if depth > 3:
841             logger.error(
842                 f"I've gotten repeated errors waiting on this bundle; giving up on pid={pid}."
843             )
844             p.terminate()
845             self.release_worker(bundle)
846             return self.emergency_retry_nasty_bundle(bundle)
847
848         # Spin until either the ssh job we scheduled finishes the
849         # bundle or some backup worker signals that they finished it
850         # before we could.
851         while True:
852             try:
853                 p.wait(timeout=0.25)
854             except subprocess.TimeoutExpired:
855                 if self.check_if_cancelled(bundle):
856                     logger.info(
857                         f'{bundle}: looks like another worker finished bundle...'
858                     )
859                     break
860             else:
861                 logger.info(f"{bundle}: pid {pid} ({machine}) is finished!")
862                 p = None
863                 break
864
865         # If we get here we believe the bundle is done; either the ssh
866         # subprocess finished (hopefully successfully) or we noticed
867         # that some other worker seems to have completed the bundle
868         # and we're bailing out.
869         try:
870             ret = self.process_work_result(bundle)
871             if ret is not None and p is not None:
872                 p.terminate()
873             return ret
874
875         # Something went wrong; e.g. we could not copy the results
876         # back, cleanup after ourselves on the remote machine, or
877         # unpickle the results we got from the remote machine.  If we
878         # still have an active ssh subprocess, keep waiting on it.
879         # Otherwise, time for an emergency reschedule.
880         except Exception as e:
881             logger.exception(e)
882             logger.error(f'{bundle}: Something unexpected just happened...')
883             if p is not None:
884                 msg = f"{bundle}: Failed to wrap up \"done\" bundle, re-waiting on active ssh."
885                 logger.warning(msg)
886                 return self.wait_for_process(p, bundle, depth + 1)
887             else:
888                 self.release_worker(bundle)
889                 return self.emergency_retry_nasty_bundle(bundle)
890
891     def process_work_result(self, bundle: BundleDetails) -> Any:
892         with self.status.lock:
893             is_original = bundle.src_bundle is None
894             was_cancelled = bundle.was_cancelled
895             username = bundle.username
896             machine = bundle.machine
897             result_file = bundle.result_file
898             code_file = bundle.code_file
899
900             # Whether original or backup, if we finished first we must
901             # fetch the results if the computation happened on a
902             # remote machine.
903             bundle.end_ts = time.time()
904             if not was_cancelled:
905                 assert bundle.machine is not None
906                 if bundle.hostname not in bundle.machine:
907                     cmd = f'{SCP} {username}@{machine}:{result_file} {result_file} 2>/dev/null'
908                     logger.info(
909                         f"{bundle}: Fetching results back from {username}@{machine} via {cmd}"
910                     )
911
912                     # If either of these throw they are handled in
913                     # wait_for_process.
914                     attempts = 0
915                     while True:
916                         try:
917                             run_silently(cmd)
918                         except Exception as e:
919                             attempts += 1
920                             if attempts >= 3:
921                                 raise e
922                         else:
923                             break
924
925                     run_silently(
926                         f'{SSH} {username}@{machine}'
927                         f' "/bin/rm -f {code_file} {result_file}"'
928                     )
929                     logger.debug(
930                         f'Fetching results back took {time.time() - bundle.end_ts:.1f}s.'
931                     )
932                 dur = bundle.end_ts - bundle.start_ts
933                 self.histogram.add_item(dur)
934
935         # Only the original worker should unpickle the file contents
936         # though since it's the only one whose result matters.  The
937         # original is also the only job that may delete result_file
938         # from disk.  Note that the original may have been cancelled
939         # if one of the backups finished first; it still must read the
940         # result from disk.
941         if is_original:
942             logger.debug(f"{bundle}: Unpickling {result_file}.")
943             try:
944                 with open(result_file, 'rb') as rb:
945                     serialized = rb.read()
946                 result = cloudpickle.loads(serialized)
947             except Exception as e:
948                 logger.exception(e)
949                 msg = f'Failed to load {result_file}... this is bad news.'
950                 logger.critical(msg)
951                 self.release_worker(bundle)
952
953                 # Re-raise the exception; the code in wait_for_process may
954                 # decide to emergency_retry_nasty_bundle here.
955                 raise e
956             logger.debug(f'Removing local (master) {code_file} and {result_file}.')
957             os.remove(result_file)
958             os.remove(code_file)
959
960             # Notify any backups that the original is done so they
961             # should stop ASAP.  Do this whether or not we
962             # finished first since there could be more than one
963             # backup.
964             if bundle.backup_bundles is not None:
965                 for backup in bundle.backup_bundles:
966                     logger.debug(
967                         f'{bundle}: Notifying backup {backup.uuid} that it\'s cancelled'
968                     )
969                     backup.is_cancelled.set()
970
971         # This is a backup job and, by now, we have already fetched
972         # the bundle results.
973         else:
974             # Backup results don't matter, they just need to leave the
975             # result file in the right place for their originals to
976             # read/unpickle later.
977             result = None
978
979             # Tell the original to stop if we finished first.
980             if not was_cancelled:
981                 logger.debug(
982                     f'{bundle}: Notifying original {bundle.src_bundle.uuid} we beat them to it.'
983                 )
984                 bundle.src_bundle.is_cancelled.set()
985         self.release_worker(bundle, was_cancelled=was_cancelled)
986         return result
987
988     def create_original_bundle(self, pickle, fname: str):
989         from string_utils import generate_uuid
990
991         uuid = generate_uuid(omit_dashes=True)
992         code_file = f'/tmp/{uuid}.code.bin'
993         result_file = f'/tmp/{uuid}.result.bin'
994
995         logger.debug(f'Writing pickled code to {code_file}')
996         with open(f'{code_file}', 'wb') as wb:
997             wb.write(pickle)
998
999         bundle = BundleDetails(
1000             pickled_code=pickle,
1001             uuid=uuid,
1002             fname=fname,
1003             worker=None,
1004             username=None,
1005             machine=None,
1006             hostname=platform.node(),
1007             code_file=code_file,
1008             result_file=result_file,
1009             pid=0,
1010             start_ts=time.time(),
1011             end_ts=0.0,
1012             slower_than_local_p95=False,
1013             slower_than_global_p95=False,
1014             src_bundle=None,
1015             is_cancelled=threading.Event(),
1016             was_cancelled=False,
1017             backup_bundles=[],
1018             failure_count=0,
1019         )
1020         self.status.record_bundle_details(bundle)
1021         logger.debug(f'{bundle}: Created an original bundle')
1022         return bundle
1023
1024     def create_backup_bundle(self, src_bundle: BundleDetails):
1025         assert src_bundle.backup_bundles is not None
1026         n = len(src_bundle.backup_bundles)
1027         uuid = src_bundle.uuid + f'_backup#{n}'
1028
1029         backup_bundle = BundleDetails(
1030             pickled_code=src_bundle.pickled_code,
1031             uuid=uuid,
1032             fname=src_bundle.fname,
1033             worker=None,
1034             username=None,
1035             machine=None,
1036             hostname=src_bundle.hostname,
1037             code_file=src_bundle.code_file,
1038             result_file=src_bundle.result_file,
1039             pid=0,
1040             start_ts=time.time(),
1041             end_ts=0.0,
1042             slower_than_local_p95=False,
1043             slower_than_global_p95=False,
1044             src_bundle=src_bundle,
1045             is_cancelled=threading.Event(),
1046             was_cancelled=False,
1047             backup_bundles=None,  # backup backups not allowed
1048             failure_count=0,
1049         )
1050         src_bundle.backup_bundles.append(backup_bundle)
1051         self.status.record_bundle_details_already_locked(backup_bundle)
1052         logger.debug(f'{backup_bundle}: Created a backup bundle')
1053         return backup_bundle
1054
1055     def schedule_backup_for_bundle(self, src_bundle: BundleDetails):
1056         assert self.status.lock.locked()
1057         assert src_bundle is not None
1058         backup_bundle = self.create_backup_bundle(src_bundle)
1059         logger.debug(
1060             f'{backup_bundle.uuid}/{backup_bundle.fname}: Scheduling backup for execution...'
1061         )
1062         self._helper_executor.submit(self.launch, backup_bundle)
1063
1064         # Results from backups don't matter; if they finish first
1065         # they will move the result_file to this machine and let
1066         # the original pick them up and unpickle them.
1067
1068     def emergency_retry_nasty_bundle(self, bundle: BundleDetails) -> fut.Future:
1069         is_original = bundle.src_bundle is None
1070         bundle.worker = None
1071         avoid_last_machine = bundle.machine
1072         bundle.machine = None
1073         bundle.username = None
1074         bundle.failure_count += 1
1075         if is_original:
1076             retry_limit = 3
1077         else:
1078             retry_limit = 2
1079
1080         if bundle.failure_count > retry_limit:
1081             logger.error(
1082                 f'{bundle}: Tried this bundle too many times already ({retry_limit}x); giving up.'
1083             )
1084             if is_original:
1085                 raise RemoteExecutorException(
1086                     f'{bundle}: This bundle can\'t be completed despite several backups and retries'
1087                 )
1088             else:
1089                 logger.error(
1090                     f'{bundle}: At least it\'s only a backup; better luck with the others.'
1091                 )
1092             return None
1093         else:
1094             msg = f'>>> Emergency rescheduling {bundle} because of unexpected errors (wtf?!) <<<'
1095             logger.warning(msg)
1096             warnings.warn(msg)
1097             return self.launch(bundle, avoid_last_machine)
1098
1099     @overrides
1100     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
1101         if self.already_shutdown:
1102             raise Exception('Submitted work after shutdown.')
1103         pickle = make_cloud_pickle(function, *args, **kwargs)
1104         bundle = self.create_original_bundle(pickle, function.__name__)
1105         self.total_bundles_submitted += 1
1106         return self._helper_executor.submit(self.launch, bundle)
1107
1108     @overrides
1109     def shutdown(self, wait=True) -> None:
1110         if not self.already_shutdown:
1111             logger.debug(f'Shutting down RemoteExecutor {self.title}')
1112             self.heartbeat_stop_event.set()
1113             self.heartbeat_thread.join()
1114             self._helper_executor.shutdown(wait)
1115             print(self.histogram)
1116             self.already_shutdown = True
1117
1118
1119 @singleton
1120 class DefaultExecutors(object):
1121     def __init__(self):
1122         self.thread_executor: Optional[ThreadExecutor] = None
1123         self.process_executor: Optional[ProcessExecutor] = None
1124         self.remote_executor: Optional[RemoteExecutor] = None
1125
1126     def ping(self, host) -> bool:
1127         logger.debug(f'RUN> ping -c 1 {host}')
1128         try:
1129             x = cmd_with_timeout(
1130                 f'ping -c 1 {host} >/dev/null 2>/dev/null', timeout_seconds=1.0
1131             )
1132             return x == 0
1133         except Exception:
1134             return False
1135
1136     def thread_pool(self) -> ThreadExecutor:
1137         if self.thread_executor is None:
1138             self.thread_executor = ThreadExecutor()
1139         return self.thread_executor
1140
1141     def process_pool(self) -> ProcessExecutor:
1142         if self.process_executor is None:
1143             self.process_executor = ProcessExecutor()
1144         return self.process_executor
1145
1146     def remote_pool(self) -> RemoteExecutor:
1147         if self.remote_executor is None:
1148             logger.info('Looking for some helper machines...')
1149             pool: List[RemoteWorkerRecord] = []
1150             if self.ping('cheetah.house'):
1151                 logger.info('Found cheetah.house')
1152                 pool.append(
1153                     RemoteWorkerRecord(
1154                         username='scott',
1155                         machine='cheetah.house',
1156                         weight=30,
1157                         count=6,
1158                     ),
1159                 )
1160             if self.ping('meerkat.cabin'):
1161                 logger.info('Found meerkat.cabin')
1162                 pool.append(
1163                     RemoteWorkerRecord(
1164                         username='scott',
1165                         machine='meerkat.cabin',
1166                         weight=12,
1167                         count=2,
1168                     ),
1169                 )
1170             if self.ping('wannabe.house'):
1171                 logger.info('Found wannabe.house')
1172                 pool.append(
1173                     RemoteWorkerRecord(
1174                         username='scott',
1175                         machine='wannabe.house',
1176                         weight=25,
1177                         count=10,
1178                     ),
1179                 )
1180             if self.ping('puma.cabin'):
1181                 logger.info('Found puma.cabin')
1182                 pool.append(
1183                     RemoteWorkerRecord(
1184                         username='scott',
1185                         machine='puma.cabin',
1186                         weight=30,
1187                         count=6,
1188                     ),
1189                 )
1190             if self.ping('backup.house'):
1191                 logger.info('Found backup.house')
1192                 pool.append(
1193                     RemoteWorkerRecord(
1194                         username='scott',
1195                         machine='backup.house',
1196                         weight=8,
1197                         count=2,
1198                     ),
1199                 )
1200
1201             # The controller machine has a lot to do; go easy on it.
1202             for record in pool:
1203                 if record.machine == platform.node() and record.count > 1:
1204                     logger.info(f'Reducing workload for {record.machine}.')
1205                     record.count = 1
1206
1207             policy = WeightedRandomRemoteWorkerSelectionPolicy()
1208             policy.register_worker_pool(pool)
1209             self.remote_executor = RemoteExecutor(pool, policy)
1210         return self.remote_executor
1211
1212     def shutdown(self) -> None:
1213         if self.thread_executor is not None:
1214             self.thread_executor.shutdown()
1215             self.thread_executor = None
1216         if self.process_executor is not None:
1217             self.process_executor.shutdown()
1218             self.process_executor = None
1219         if self.remote_executor is not None:
1220             self.remote_executor.shutdown()
1221             self.remote_executor = None
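
# Typical entry point (a sketch; some_function / some_arg are placeholders).
# DefaultExecutors is decorated @singleton above, so repeated construction is
# expected to hand back the same shared instance:
#
#     executors = DefaultExecutors()
#     pool = executors.thread_pool()  # or process_pool() / remote_pool()
#     future = pool.submit(some_function, some_arg)
#     print(future.result())
#     executors.shutdown()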