1 #!/usr/bin/env python3
2
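"""Executors for parallelizing work.

This module defines a ThreadExecutor (local thread pool), a
ProcessExecutor (local process pool fed cloudpickled bundles) and a
RemoteExecutor (ships bundles to other machines over ssh/scp), plus a
DefaultExecutors singleton that lazily constructs one of each.
"""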
3 from __future__ import annotations
4
5 from abc import ABC, abstractmethod
6 import concurrent.futures as fut
7 from collections import defaultdict
8 from dataclasses import dataclass
9 import logging
10 import numpy
11 import os
12 import platform
13 import random
14 import subprocess
15 import threading
16 import time
17 from typing import Any, Callable, Dict, List, Optional, Set
18 import warnings
19
20 import cloudpickle  # type: ignore
21 from overrides import overrides
22
23 from ansi import bg, fg, underline, reset
24 import argparse_utils
25 import config
26 from decorator_utils import singleton
27 from exec_utils import run_silently, cmd_in_background, cmd_with_timeout
28 import histogram as hist
29 from thread_utils import background_thread
30
31
32 logger = logging.getLogger(__name__)
33
34 parser = config.add_commandline_args(
35     f"Executors ({__file__})", "Args related to processing executors."
36 )
37 parser.add_argument(
38     '--executors_threadpool_size',
39     type=int,
40     metavar='#THREADS',
41     help='Number of threads in the default threadpool, leave unset for default',
42     default=None,
43 )
44 parser.add_argument(
45     '--executors_processpool_size',
46     type=int,
47     metavar='#PROCESSES',
48     help='Number of processes in the default processpool, leave unset for default',
49     default=None,
50 )
51 parser.add_argument(
52     '--executors_schedule_remote_backups',
53     default=True,
54     action=argparse_utils.ActionNoYes,
55     help='Should we schedule duplicative backup work if a remote bundle is slow',
56 )
57 parser.add_argument(
58     '--executors_max_bundle_failures',
59     type=int,
60     default=3,
61     metavar='#FAILURES',
62     help='Maximum number of failures before giving up on a bundle',
63 )
64
65 SSH = '/usr/bin/ssh -oForwardX11=no'
66 SCP = '/usr/bin/scp'
67
68
69 def make_cloud_pickle(fun, *args, **kwargs):
70     logger.debug(f"Making cloudpickled bundle at {fun.__name__}")
71     return cloudpickle.dumps((fun, args, kwargs))
72
73
74 class BaseExecutor(ABC):
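    """Abstract base class for the executors below.  Subclasses provide
    submit() and shutdown(); this class tracks an outstanding-task count
    and a histogram of task latencies for reporting at shutdown."""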
75     def __init__(self, *, title=''):
76         self.title = title
77         self.task_count = 0
78         self.histogram = hist.SimpleHistogram(
79             hist.SimpleHistogram.n_evenly_spaced_buckets(0, 500, 50)
80         )
81
82     @abstractmethod
83     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
84         pass
85
86     @abstractmethod
87     def shutdown(self, wait: bool = True) -> None:
88         pass
89
90     def adjust_task_count(self, delta: int) -> None:
91         self.task_count += delta
92         logger.debug(f'Executor current task count is {self.task_count}')
93
94
95 class ThreadExecutor(BaseExecutor):
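    """An executor that runs work in a local thread pool.

    A minimal usage sketch (do_work is a hypothetical callable):

        executor = ThreadExecutor()
        future = executor.submit(do_work, 1, 2)
        print(future.result())
        executor.shutdown()
    """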
96     def __init__(self, max_workers: Optional[int] = None):
97         super().__init__()
98         workers = None
99         if max_workers is not None:
100             workers = max_workers
101         elif 'executors_threadpool_size' in config.config:
102             workers = config.config['executors_threadpool_size']
103         logger.debug(f'Creating threadpool executor with {workers} workers')
104         self._thread_pool_executor = fut.ThreadPoolExecutor(
105             max_workers=workers, thread_name_prefix="thread_executor_helper"
106         )
107
108     def run_local_bundle(self, fun, *args, **kwargs):
109         logger.debug(f"Running local bundle at {fun.__name__}")
110         start = time.time()
111         result = fun(*args, **kwargs)
112         end = time.time()
113         self.adjust_task_count(-1)
114         duration = end - start
115         logger.debug(f"{fun.__name__} finished; used {duration:.1f}s")
116         self.histogram.add_item(duration)
117         return result
118
119     @overrides
120     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
121         self.adjust_task_count(+1)
126         return self._thread_pool_executor.submit(
127             self.run_local_bundle, function, *args, **kwargs
128         )
129
130     @overrides
131     def shutdown(self, wait=True) -> None:
132         logger.debug(f'Shutting down threadpool executor {self.title}')
133         print(self.histogram)
134         self._thread_pool_executor.shutdown(wait)
135
136
137 class ProcessExecutor(BaseExecutor):
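    """An executor that runs work in a local process pool.

    Submitted work is serialized with cloudpickle before being handed to
    the pool, so callables that the stdlib pickler rejects (e.g. lambdas)
    can still be submitted.  Usage mirrors ThreadExecutor.
    """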
138     def __init__(self, max_workers=None):
139         super().__init__()
140         workers = None
141         if max_workers is not None:
142             workers = max_workers
143         elif 'executors_processpool_size' in config.config:
144             workers = config.config['executors_processpool_size']
145         logger.debug(f'Creating processpool executor with {workers} workers.')
146         self._process_executor = fut.ProcessPoolExecutor(
147             max_workers=workers,
148         )
149
150     def run_cloud_pickle(self, pickle):
151         fun, args, kwargs = cloudpickle.loads(pickle)
152         logger.debug(f"Running pickled bundle at {fun.__name__}")
153         result = fun(*args, **kwargs)
154         self.adjust_task_count(-1)
155         return result
156
157     @overrides
158     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
159         start = time.time()
160         self.adjust_task_count(+1)
161         pickle = make_cloud_pickle(function, *args, **kwargs)
162         result = self._process_executor.submit(self.run_cloud_pickle, pickle)
163         result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
164         return result
165
166     @overrides
167     def shutdown(self, wait=True) -> None:
168         logger.debug(f'Shutting down processpool executor {self.title}')
169         self._process_executor.shutdown(wait)
170         print(self.histogram)
171
172     def __getstate__(self):
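        # Exclude the (unpicklable) pool itself when this executor is
        # serialized; only the lightweight bookkeeping state survives.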
173         state = self.__dict__.copy()
174         state['_process_executor'] = None
175         return state
176
177
178 class RemoteExecutorException(Exception):
179     """Thrown when a bundle cannot be executed despite several retries."""
180
181     pass
182
183
184 @dataclass
185 class RemoteWorkerRecord:
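    """Describes a remote worker: the ssh username, the machine name, a
    scheduling weight and how many bundles it may run concurrently."""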
186     username: str
187     machine: str
188     weight: int
189     count: int
190
191     def __hash__(self):
192         return hash((self.username, self.machine))
193
194     def __repr__(self):
195         return f'{self.username}@{self.machine}'
196
197
198 @dataclass
199 class BundleDetails:
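    """Everything known about one unit of work (a "bundle"): the pickled
    code, where and when it is running, timing/latency flags, and the
    relationship between an original bundle and its backups."""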
200     pickled_code: bytes
201     uuid: str
202     fname: str
203     worker: Optional[RemoteWorkerRecord]
204     username: Optional[str]
205     machine: Optional[str]
206     hostname: str
207     code_file: str
208     result_file: str
209     pid: int
210     start_ts: float
211     end_ts: float
212     slower_than_local_p95: bool
213     slower_than_global_p95: bool
214     src_bundle: BundleDetails
215     is_cancelled: threading.Event
216     was_cancelled: bool
217     backup_bundles: Optional[List[BundleDetails]]
218     failure_count: int
219
220     def __repr__(self):
221         uuid = self.uuid
222         if uuid[-9:-2] == '_backup':
223             uuid = uuid[:-9]
224             suffix = f'{uuid[-6:]}_b{self.uuid[-1:]}'
225         else:
226             suffix = uuid[-6:]
227
228         colorz = [
229             fg('violet red'),
230             fg('red'),
231             fg('orange'),
232             fg('peach orange'),
233             fg('yellow'),
234             fg('marigold yellow'),
235             fg('green yellow'),
236             fg('tea green'),
237             fg('cornflower blue'),
238             fg('turquoise blue'),
239             fg('tropical blue'),
240             fg('lavender purple'),
241             fg('medium purple'),
242         ]
243         c = colorz[int(uuid[-2:], 16) % len(colorz)]
244         fname = self.fname if self.fname is not None else 'nofname'
245         machine = self.machine if self.machine is not None else 'nomachine'
246         return f'{c}{suffix}/{fname}/{machine}{reset()}'
247
248
249 class RemoteExecutorStatus:
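    """Bookkeeping for RemoteExecutor: known workers, in-flight bundles,
    and per-bundle / per-worker timings used to compute the p50/p95
    figures printed by periodic_dump()."""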
250     def __init__(self, total_worker_count: int) -> None:
251         self.worker_count = total_worker_count
252         self.known_workers: Set[RemoteWorkerRecord] = set()
253         self.start_per_bundle: Dict[str, Optional[float]] = defaultdict(float)
254         self.end_per_bundle: Dict[str, float] = defaultdict(float)
255         self.finished_bundle_timings_per_worker: Dict[
256             RemoteWorkerRecord, List[float]
257         ] = {}
258         self.in_flight_bundles_by_worker: Dict[RemoteWorkerRecord, Set[str]] = {}
259         self.bundle_details_by_uuid: Dict[str, BundleDetails] = {}
260         self.finished_bundle_timings: List[float] = []
261         self.last_periodic_dump: Optional[float] = None
262         self.total_bundles_submitted = 0
263
264         # Protects reads and modifications of self.  Also used as a
265         # memory fence for modifications to bundles.
266         self.lock = threading.Lock()
267
268     def record_acquire_worker(self, worker: RemoteWorkerRecord, uuid: str) -> None:
269         with self.lock:
270             self.record_acquire_worker_already_locked(worker, uuid)
271
272     def record_acquire_worker_already_locked(
273         self, worker: RemoteWorkerRecord, uuid: str
274     ) -> None:
275         assert self.lock.locked()
276         self.known_workers.add(worker)
277         self.start_per_bundle[uuid] = None
278         x = self.in_flight_bundles_by_worker.get(worker, set())
279         x.add(uuid)
280         self.in_flight_bundles_by_worker[worker] = x
281
282     def record_bundle_details(self, details: BundleDetails) -> None:
283         with self.lock:
284             self.record_bundle_details_already_locked(details)
285
286     def record_bundle_details_already_locked(self, details: BundleDetails) -> None:
287         assert self.lock.locked()
288         self.bundle_details_by_uuid[details.uuid] = details
289
290     def record_release_worker(
291         self,
292         worker: RemoteWorkerRecord,
293         uuid: str,
294         was_cancelled: bool,
295     ) -> None:
296         with self.lock:
297             self.record_release_worker_already_locked(worker, uuid, was_cancelled)
298
299     def record_release_worker_already_locked(
300         self,
301         worker: RemoteWorkerRecord,
302         uuid: str,
303         was_cancelled: bool,
304     ) -> None:
305         assert self.lock.locked()
306         ts = time.time()
307         self.end_per_bundle[uuid] = ts
308         self.in_flight_bundles_by_worker[worker].remove(uuid)
309         if not was_cancelled:
310             bundle_latency = ts - self.start_per_bundle[uuid]
311             x = self.finished_bundle_timings_per_worker.get(worker, list())
312             x.append(bundle_latency)
313             self.finished_bundle_timings_per_worker[worker] = x
314             self.finished_bundle_timings.append(bundle_latency)
315
316     def record_processing_began(self, uuid: str):
317         with self.lock:
318             self.start_per_bundle[uuid] = time.time()
319
320     def total_in_flight(self) -> int:
321         assert self.lock.locked()
322         total_in_flight = 0
323         for worker in self.known_workers:
324             total_in_flight += len(self.in_flight_bundles_by_worker[worker])
325         return total_in_flight
326
327     def total_idle(self) -> int:
328         assert self.lock.locked()
329         return self.worker_count - self.total_in_flight()
330
331     def __repr__(self):
332         assert self.lock.locked()
333         ts = time.time()
334         total_finished = len(self.finished_bundle_timings)
335         total_in_flight = self.total_in_flight()
336         ret = f'\n\n{underline()}Remote Executor Pool Status{reset()}: '
337         qall = None
338         if len(self.finished_bundle_timings) > 1:
339             qall = numpy.quantile(self.finished_bundle_timings, [0.5, 0.95])
340             ret += (
341                 f'⏱=∀p50:{qall[0]:.1f}s, ∀p95:{qall[1]:.1f}s, '
342                 f'✅={total_finished}/{self.total_bundles_submitted}, '
343                 f'💻n={total_in_flight}/{self.worker_count}\n'
344             )
345         else:
346             ret += (
347                 f' ✅={total_finished}/{self.total_bundles_submitted}, '
348                 f'💻n={total_in_flight}/{self.worker_count}\n'
349             )
350
351         for worker in self.known_workers:
352             ret += f'  {fg("lightning yellow")}{worker.machine}{reset()}: '
353             timings = self.finished_bundle_timings_per_worker.get(worker, [])
354             count = len(timings)
355             qworker = None
356             if count > 1:
357                 qworker = numpy.quantile(timings, [0.5, 0.95])
358                 ret += f' 💻p50: {qworker[0]:.1f}s, 💻p95: {qworker[1]:.1f}s\n'
359             else:
360                 ret += '\n'
361             if count > 0:
362                 ret += f'    ...finished {count} total bundle(s) so far\n'
363             in_flight = len(self.in_flight_bundles_by_worker[worker])
364             if in_flight > 0:
365                 ret += f'    ...{in_flight} bundles currently in flight:\n'
366                 for bundle_uuid in self.in_flight_bundles_by_worker[worker]:
367                     details = self.bundle_details_by_uuid.get(bundle_uuid, None)
368                     pid = str(details.pid) if (details and details.pid != 0) else "TBD"
369                     if self.start_per_bundle[bundle_uuid] is not None:
370                         sec = ts - self.start_per_bundle[bundle_uuid]
371                         ret += f'       (pid={pid}): {details} for {sec:.1f}s so far '
372                     else:
373                         ret += f'       {details} setting up / copying data...'
374                         sec = 0.0
375
376                     if qworker is not None:
377                         if sec > qworker[1]:
378                             ret += f'{bg("red")}>💻p95{reset()} '
379                             if details is not None:
380                                 details.slower_than_local_p95 = True
381                         else:
382                             if details is not None:
383                                 details.slower_than_local_p95 = False
384
385                     if qall is not None:
386                         if sec > qall[1]:
387                             ret += f'{bg("red")}>∀p95{reset()} '
388                             if details is not None:
389                                 details.slower_than_global_p95 = True
390                         elif details is not None:
391                             details.slower_than_global_p95 = False
392                     ret += '\n'
393         return ret
394
395     def periodic_dump(self, total_bundles_submitted: int) -> None:
396         assert self.lock.locked()
397         self.total_bundles_submitted = total_bundles_submitted
398         ts = time.time()
399         if self.last_periodic_dump is None or ts - self.last_periodic_dump > 5.0:
400             print(self)
401             self.last_periodic_dump = ts
402
403
404 class RemoteWorkerSelectionPolicy(ABC):
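    """Chooses which remote worker should receive the next bundle."""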
405     def register_worker_pool(self, workers):
406         self.workers = workers
407
408     @abstractmethod
409     def is_worker_available(self) -> bool:
410         pass
411
412     @abstractmethod
413     def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
414         pass
415
416
417 class WeightedRandomRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
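    """Picks a random available worker, weighted by each worker's
    configured weight and remaining capacity."""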
418     @overrides
419     def is_worker_available(self) -> bool:
420         for worker in self.workers:
421             if worker.count > 0:
422                 return True
423         return False
424
425     @overrides
426     def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
427         grabbag = []
428         for worker in self.workers:
429             for x in range(0, worker.count):
430                 for y in range(0, worker.weight):
431                     grabbag.append(worker)
432
433         for attempt in range(0, 5):
434             random.shuffle(grabbag)
435             worker = grabbag[0]
436             if worker.machine != machine_to_avoid or attempt > 2:
437                 if worker.count > 0:
438                     worker.count -= 1
439                     logger.debug(f'Selected worker {worker}')
440                     return worker
441         msg = 'Unexpectedly could not find a worker, retrying...'
442         logger.warning(msg)
443         return None
444
445
446 class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
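    """Hands out available workers in round-robin order."""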
447     def __init__(self) -> None:
448         self.index = 0
449
450     @overrides
451     def is_worker_available(self) -> bool:
452         for worker in self.workers:
453             if worker.count > 0:
454                 return True
455         return False
456
457     @overrides
458     def acquire_worker(
459         self, machine_to_avoid: Optional[str] = None
460     ) -> Optional[RemoteWorkerRecord]:
461         x = self.index
462         while True:
463             worker = self.workers[x]
464             if worker.count > 0:
465                 worker.count -= 1
466                 x += 1
467                 if x >= len(self.workers):
468                     x = 0
469                 self.index = x
470                 logger.debug(f'Selected worker {worker}')
471                 return worker
472             x += 1
473             if x >= len(self.workers):
474                 x = 0
475             if x == self.index:
476                 msg = 'Unexpectedly could not find a worker, retrying...'
477                 logger.warning(msg)
478                 return None
479
480
481 class RemoteExecutor(BaseExecutor):
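    """An executor that runs cloudpickled bundles on remote machines via
    ssh and copies results back with scp.

    A minimal construction sketch; the hostname is illustrative, do_work
    is a hypothetical callable, and passwordless ssh plus the
    remote_worker.py helper are assumed to be set up on the worker:

        workers = [RemoteWorkerRecord('scott', 'cheetah.house', 25, 6)]
        policy = WeightedRandomRemoteWorkerSelectionPolicy()
        executor = RemoteExecutor(workers, policy)
        future = executor.submit(do_work, 1, 2)
        print(future.result())
        executor.shutdown()

    If a bundle is slow, a backup copy may be scheduled on an idle
    worker (see maybe_schedule_backup_bundles); whichever copy finishes
    first wins and the other is cancelled.
    """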
482     def __init__(
483         self, workers: List[RemoteWorkerRecord], policy: RemoteWorkerSelectionPolicy
484     ) -> None:
485         super().__init__()
486         self.workers = workers
487         self.policy = policy
488         self.worker_count = 0
489         for worker in self.workers:
490             self.worker_count += worker.count
491         if self.worker_count <= 0:
492             msg = f"We need somewhere to schedule work; count was {self.worker_count}"
493             logger.critical(msg)
494             raise RemoteExecutorException(msg)
495         self.policy.register_worker_pool(self.workers)
496         self.cv = threading.Condition()
497         logger.debug(
498             f'Creating {self.worker_count} local threads, one per remote worker.'
499         )
500         self._helper_executor = fut.ThreadPoolExecutor(
501             thread_name_prefix="remote_executor_helper",
502             max_workers=self.worker_count,
503         )
504         self.status = RemoteExecutorStatus(self.worker_count)
505         self.total_bundles_submitted = 0
506         self.backup_lock = threading.Lock()
507         self.last_backup = None
508         (
509             self.heartbeat_thread,
510             self.heartbeat_stop_event,
511         ) = self.run_periodic_heartbeat()
512
513     @background_thread
514     def run_periodic_heartbeat(self, stop_event) -> None:
515         while not stop_event.is_set():
516             time.sleep(5.0)
517             logger.debug('Running periodic heartbeat code...')
518             self.heartbeat()
519         logger.debug('Periodic heartbeat thread shutting down.')
520
521     def heartbeat(self) -> None:
522         with self.status.lock:
523             # Dump regular progress report
524             self.status.periodic_dump(self.total_bundles_submitted)
525
526             # Look for bundles to reschedule via executor.submit
527             if config.config['executors_schedule_remote_backups']:
528                 self.maybe_schedule_backup_bundles()
529
530     def maybe_schedule_backup_bundles(self):
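        """Consider scheduling a backup copy of one in-flight bundle.

        Runs on the heartbeat thread with self.status.lock held.  Each
        in-flight original bundle is scored (slower machines, longer
        runtimes and fewer existing backups all raise the score) and the
        best candidate, if any, is resubmitted; at most one backup is
        scheduled every ~6 seconds.
        """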
531         assert self.status.lock.locked()
532         num_done = len(self.status.finished_bundle_timings)
533         num_idle_workers = self.worker_count - self.task_count
534         now = time.time()
535         if (
536             num_done > 2
537             and num_idle_workers > 1
538             and (self.last_backup is None or (now - self.last_backup > 6.0))
539             and self.backup_lock.acquire(blocking=False)
540         ):
541             try:
542                 assert self.backup_lock.locked()
543
544                 bundle_to_backup = None
545                 best_score = None
546                 for (
547                     worker,
548                     bundle_uuids,
549                 ) in self.status.in_flight_bundles_by_worker.items():
550
551                     # Prefer to schedule backups of bundles running on
552                     # slower machines.
553                     base_score = 0
554                     for record in self.workers:
555                         if worker.machine == record.machine:
556                             base_score = float(record.weight)
557                             base_score = 1.0 / base_score
558                             base_score *= 200.0
559                             base_score = int(base_score)
560                             break
561
562                     for uuid in bundle_uuids:
563                         bundle = self.status.bundle_details_by_uuid.get(uuid, None)
564                         if (
565                             bundle is not None
566                             and bundle.src_bundle is None
567                             and bundle.backup_bundles is not None
568                         ):
569                             score = base_score
570
571                             # Schedule backups of bundles running
572                             # longer; especially those that are
573                             # unexpectedly slow.
574                             start_ts = self.status.start_per_bundle[uuid]
575                             if start_ts is not None:
576                                 runtime = now - start_ts
577                                 score += runtime
578                                 logger.debug(
579                                     f'score[{bundle}] => {score}  # latency boost'
580                                 )
581
582                                 if bundle.slower_than_local_p95:
583                                     score += runtime / 2
584                                     logger.debug(
585                                         f'score[{bundle}] => {score}  # >worker p95'
586                                     )
587
588                                 if bundle.slower_than_global_p95:
589                                     score += runtime / 4
590                                     logger.debug(
591                                         f'score[{bundle}] => {score}  # >global p95'
592                                     )
593
594                             # Prefer backups of bundles that don't
595                             # have backups already.
596                             backup_count = len(bundle.backup_bundles)
597                             if backup_count == 0:
598                                 score *= 2
599                             elif backup_count == 1:
600                                 score /= 2
601                             elif backup_count == 2:
602                                 score /= 8
603                             else:
604                                 score = 0
605                             logger.debug(
606                                 f'score[{bundle}] => {score}  # {backup_count} dup backup factor'
607                             )
608
609                             if score != 0 and (
610                                 best_score is None or score > best_score
611                             ):
612                                 bundle_to_backup = bundle
613                                 assert bundle is not None
614                                 assert bundle.backup_bundles is not None
615                                 assert bundle.src_bundle is None
616                                 best_score = score
617
618                 # Note: this is all still happening on the heartbeat
619                 # runner thread.  That's ok because
620                 # schedule_backup_for_bundle uses the executor to
621                 # submit the bundle again which will cause it to be
622                 # picked up by a worker thread and allow this thread
623                 # to return to run future heartbeats.
624                 if bundle_to_backup is not None:
625                     self.last_backup = now
626                     logger.info(
627                         f'=====> SCHEDULING BACKUP {bundle_to_backup} (score={best_score:.1f}) <====='
628                     )
629                     self.schedule_backup_for_bundle(bundle_to_backup)
630             finally:
631                 self.backup_lock.release()
632
633     def is_worker_available(self) -> bool:
634         return self.policy.is_worker_available()
635
636     def acquire_worker(
637         self, machine_to_avoid: Optional[str] = None
638     ) -> Optional[RemoteWorkerRecord]:
639         return self.policy.acquire_worker(machine_to_avoid)
640
641     def find_available_worker_or_block(
642         self, machine_to_avoid: Optional[str] = None
643     ) -> RemoteWorkerRecord:
644         with self.cv:
645             while not self.is_worker_available():
646                 self.cv.wait()
647             worker = self.acquire_worker(machine_to_avoid)
648             if worker is not None:
649                 return worker
650         msg = "We should never reach this point in the code"
651         logger.critical(msg)
652         raise Exception(msg)
653
654     def release_worker(self, bundle: BundleDetails, *, was_cancelled=True) -> None:
655         worker = bundle.worker
656         assert worker is not None
657         logger.debug(f'Released worker {worker}')
658         self.status.record_release_worker(
659             worker,
660             bundle.uuid,
661             was_cancelled,
662         )
663         with self.cv:
664             worker.count += 1
665             self.cv.notify()
666         self.adjust_task_count(-1)
667
668     def check_if_cancelled(self, bundle: BundleDetails) -> bool:
669         with self.status.lock:
670             if bundle.is_cancelled.wait(timeout=0.0):
671                 logger.debug(f'Bundle {bundle.uuid} is cancelled, bail out.')
672                 bundle.was_cancelled = True
673                 return True
674         return False
675
676     def launch(self, bundle: BundleDetails, override_avoid_machine=None) -> Any:
677         """Find a worker for bundle or block until one is available."""
678         self.adjust_task_count(+1)
679         uuid = bundle.uuid
680         hostname = bundle.hostname
681         avoid_machine = override_avoid_machine
682         is_original = bundle.src_bundle is None
683
684         # Try not to schedule a backup on the same host as the original.
685         if avoid_machine is None and bundle.src_bundle is not None:
686             avoid_machine = bundle.src_bundle.machine
687         worker = None
688         while worker is None:
689             worker = self.find_available_worker_or_block(avoid_machine)
690         assert worker
691
692         # Ok, found a worker.
693         bundle.worker = worker
694         machine = bundle.machine = worker.machine
695         username = bundle.username = worker.username
696         self.status.record_acquire_worker(worker, uuid)
697         logger.debug(f'{bundle}: Running bundle on {worker}...')
698
699         # Before we do any work, make sure the bundle is still viable.
700         # Considerable time may have passed between submission and now
701         # (due to a lack of worker availability) and someone else may
702         # have already finished it.
703         if self.check_if_cancelled(bundle):
704             try:
705                 return self.process_work_result(bundle)
706             except Exception as e:
707                 logger.warning(
708                     f'{bundle}: bundle says it\'s cancelled upfront but no results?!'
709                 )
710                 self.release_worker(bundle)
711                 if is_original:
712                     # Weird.  We are the original owner of this
713                     # bundle.  For it to have been cancelled, a backup
714                     # must have already started and completed before
715                     # we even got started.  Moreover, the backup says
716                     # it is done but we can't find the results it
717                     # should have copied over.  Reschedule the whole
718                     # thing.
719                     logger.exception(e)
720                     logger.error(
721                         f'{bundle}: We are the original owner thread and yet there are '
722                         + 'no results for this bundle.  This is unexpected and bad.'
723                     )
724                     return self.emergency_retry_nasty_bundle(bundle)
725                 else:
726                     # Expected(?).  We're a backup and our bundle is
727                     # cancelled before we even got started.  Something
728                     # went bad in process_work_result (I actually don't
729                     # see what?) but probably not worth worrying
730                     # about.  Let the original thread worry about
731                     # either finding the results or complaining about
732                     # it.
733                     return None
734
735         # Send input code / data to worker machine if it's not local.
736         if hostname not in machine:
737             try:
738                 cmd = (
739                     f'{SCP} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
740                 )
741                 start_ts = time.time()
742                 logger.info(f"{bundle}: Copying work to {worker} via {cmd}.")
743                 run_silently(cmd)
744                 xfer_latency = time.time() - start_ts
745                 logger.debug(f"{bundle}: Copying to {worker} took {xfer_latency:.1f}s.")
746             except Exception as e:
747                 self.release_worker(bundle)
748                 if is_original:
749                     # Weird.  We tried to copy the code to the worker and it failed...
750                     # And we're the original bundle.  We have to retry.
751                     logger.exception(e)
752                     logger.error(
753                         f"{bundle}: Failed to send instructions to the worker machine?! "
754                         + "This is not expected; we\'re the original bundle so this shouldn\'t "
755                         + "be a race condition.  Attempting an emergency retry..."
756                     )
757                     return self.emergency_retry_nasty_bundle(bundle)
758                 else:
759                     # This is actually expected; we're a backup.
760                     # There's a race condition where someone else
761                     # already finished the work and removed the source
762                     # code file before we could copy it.  No biggie.
763                     msg = f'{bundle}: Failed to send instructions to the worker machine... '
764                     msg += 'We\'re a backup and this may be caused by the original (or some '
765                     msg += 'other backup) already finishing this work.  Ignoring this.'
766                     logger.warning(msg)
767                     return None
768
769         # Kick off the work.  Note that if this fails we let
770         # wait_for_process deal with it.
771         self.status.record_processing_began(uuid)
772         cmd = (
773             f'{SSH} {bundle.username}@{bundle.machine} '
774             f'"source py38-venv/bin/activate &&'
775             f' /home/scott/lib/python_modules/remote_worker.py'
776             f' --code_file {bundle.code_file} --result_file {bundle.result_file}"'
777         )
778         logger.debug(f'{bundle}: Executing {cmd} in the background to kick off work...')
779         p = cmd_in_background(cmd, silent=True)
780         bundle.pid = p.pid
781         logger.debug(
782             f'{bundle}: Local ssh process pid={p.pid}; remote worker is {machine}.'
783         )
784         return self.wait_for_process(p, bundle, 0)
785
786     def wait_for_process(
787         self, p: subprocess.Popen, bundle: BundleDetails, depth: int
788     ) -> Any:
789         machine = bundle.machine
790         pid = p.pid
791         if depth > 3:
792             logger.error(
793                 f"I've gotten repeated errors waiting on this bundle; giving up on pid={pid}."
794             )
795             p.terminate()
796             self.release_worker(bundle)
797             return self.emergency_retry_nasty_bundle(bundle)
798
799         # Spin until either the ssh job we scheduled finishes the
800         # bundle or some backup worker signals that they finished it
801         # before we could.
802         while True:
803             try:
804                 p.wait(timeout=0.25)
805             except subprocess.TimeoutExpired:
806                 if self.check_if_cancelled(bundle):
807                     logger.info(
808                         f'{bundle}: looks like another worker finished bundle...'
809                     )
810                     break
811             else:
812                 logger.info(f"{bundle}: pid {pid} ({machine}) is finished!")
813                 p = None
814                 break
815
816         # If we get here we believe the bundle is done; either the ssh
817         # subprocess finished (hopefully successfully) or we noticed
818         # that some other worker seems to have completed the bundle
819         # and we're bailing out.
820         try:
821             ret = self.process_work_result(bundle)
822             if ret is not None and p is not None:
823                 p.terminate()
824             return ret
825
826         # Something went wrong; e.g. we could not copy the results
827         # back, cleanup after ourselves on the remote machine, or
828         # unpickle the results we got from the remote machine.  If we
829         # still have an active ssh subprocess, keep waiting on it.
830         # Otherwise, time for an emergency reschedule.
831         except Exception as e:
832             logger.exception(e)
833             logger.error(f'{bundle}: Something unexpected just happened...')
834             if p is not None:
835                 msg = f"{bundle}: Failed to wrap up \"done\" bundle, re-waiting on active ssh."
836                 logger.warning(msg)
837                 return self.wait_for_process(p, bundle, depth + 1)
838             else:
839                 self.release_worker(bundle)
840                 return self.emergency_retry_nasty_bundle(bundle)
841
842     def process_work_result(self, bundle: BundleDetails) -> Any:
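        """Wrap up a finished bundle: copy the result file back if the
        work ran remotely, unpickle it if we are the original bundle,
        and notify the counterpart (original or backups) to stop."""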
843         with self.status.lock:
844             is_original = bundle.src_bundle is None
845             was_cancelled = bundle.was_cancelled
846             username = bundle.username
847             machine = bundle.machine
848             result_file = bundle.result_file
849             code_file = bundle.code_file
850
851             # Whether original or backup, if we finished first we must
852             # fetch the results if the computation happened on a
853             # remote machine.
854             bundle.end_ts = time.time()
855             if not was_cancelled:
856                 assert bundle.machine is not None
857                 if bundle.hostname not in bundle.machine:
858                     cmd = f'{SCP} {username}@{machine}:{result_file} {result_file} 2>/dev/null'
859                     logger.info(
860                         f"{bundle}: Fetching results back from {username}@{machine} via {cmd}"
861                     )
862
863                     # If either of these throw they are handled in
864                     # wait_for_process.
865                     attempts = 0
866                     while True:
867                         try:
868                             run_silently(cmd)
869                         except Exception as e:
870                             attempts += 1
871                             if attempts >= 3:
872                                 raise e
873                         else:
874                             break
875
876                     run_silently(
877                         f'{SSH} {username}@{machine}'
878                         f' "/bin/rm -f {code_file} {result_file}"'
879                     )
880                     logger.debug(
881                         f'Fetching results back took {time.time() - bundle.end_ts:.1f}s.'
882                     )
883                 dur = bundle.end_ts - bundle.start_ts
884                 self.histogram.add_item(dur)
885
886         # Only the original worker should unpickle the file contents
887         # though since it's the only one whose result matters.  The
888         # original is also the only job that may delete result_file
889         # from disk.  Note that the original may have been cancelled
890         # if one of the backups finished first; it still must read the
891         # result from disk.
892         if is_original:
893             logger.debug(f"{bundle}: Unpickling {result_file}.")
894             try:
895                 with open(result_file, 'rb') as rb:
896                     serialized = rb.read()
897                 result = cloudpickle.loads(serialized)
898             except Exception as e:
899                 logger.exception(e)
900                 msg = f'Failed to load {result_file}... this is bad news.'
901                 logger.critical(msg)
902                 self.release_worker(bundle)
903
904                 # Re-raise the exception; the code in wait_for_process may
905                 # decide to emergency_retry_nasty_bundle here.
906                 raise Exception(e)
907             logger.debug(f'Removing local (master) {code_file} and {result_file}.')
908             os.remove(result_file)
909             os.remove(code_file)
910
911             # Notify any backups that the original is done so they
912             # should stop ASAP.  Do this whether or not we
913             # finished first since there could be more than one
914             # backup.
915             if bundle.backup_bundles is not None:
916                 for backup in bundle.backup_bundles:
917                     logger.debug(
918                         f'{bundle}: Notifying backup {backup.uuid} that it\'s cancelled'
919                     )
920                     backup.is_cancelled.set()
921
922         # This is a backup job and, by now, we have already fetched
923         # the bundle results.
924         else:
925             # Backup results don't matter, they just need to leave the
926             # result file in the right place for their originals to
927             # read/unpickle later.
928             result = None
929
930             # Tell the original to stop if we finished first.
931             if not was_cancelled:
932                 logger.debug(
933                     f'{bundle}: Notifying original {bundle.src_bundle.uuid} we beat them to it.'
934                 )
935                 bundle.src_bundle.is_cancelled.set()
936         self.release_worker(bundle, was_cancelled=was_cancelled)
937         return result
938
939     def create_original_bundle(self, pickle, fname: str):
940         from string_utils import generate_uuid
941
942         uuid = generate_uuid(omit_dashes=True)
943         code_file = f'/tmp/{uuid}.code.bin'
944         result_file = f'/tmp/{uuid}.result.bin'
945
946         logger.debug(f'Writing pickled code to {code_file}')
947         with open(code_file, 'wb') as wb:
948             wb.write(pickle)
949
950         bundle = BundleDetails(
951             pickled_code=pickle,
952             uuid=uuid,
953             fname=fname,
954             worker=None,
955             username=None,
956             machine=None,
957             hostname=platform.node(),
958             code_file=code_file,
959             result_file=result_file,
960             pid=0,
961             start_ts=time.time(),
962             end_ts=0.0,
963             slower_than_local_p95=False,
964             slower_than_global_p95=False,
965             src_bundle=None,
966             is_cancelled=threading.Event(),
967             was_cancelled=False,
968             backup_bundles=[],
969             failure_count=0,
970         )
971         self.status.record_bundle_details(bundle)
972         logger.debug(f'{bundle}: Created an original bundle')
973         return bundle
974
975     def create_backup_bundle(self, src_bundle: BundleDetails):
976         assert src_bundle.backup_bundles is not None
977         n = len(src_bundle.backup_bundles)
978         uuid = src_bundle.uuid + f'_backup#{n}'
979
980         backup_bundle = BundleDetails(
981             pickled_code=src_bundle.pickled_code,
982             uuid=uuid,
983             fname=src_bundle.fname,
984             worker=None,
985             username=None,
986             machine=None,
987             hostname=src_bundle.hostname,
988             code_file=src_bundle.code_file,
989             result_file=src_bundle.result_file,
990             pid=0,
991             start_ts=time.time(),
992             end_ts=0.0,
993             slower_than_local_p95=False,
994             slower_than_global_p95=False,
995             src_bundle=src_bundle,
996             is_cancelled=threading.Event(),
997             was_cancelled=False,
998             backup_bundles=None,  # backup backups not allowed
999             failure_count=0,
1000         )
1001         src_bundle.backup_bundles.append(backup_bundle)
1002         self.status.record_bundle_details_already_locked(backup_bundle)
1003         logger.debug(f'{backup_bundle}: Created a backup bundle')
1004         return backup_bundle
1005
1006     def schedule_backup_for_bundle(self, src_bundle: BundleDetails):
1007         assert self.status.lock.locked()
1008         assert src_bundle is not None
1009         backup_bundle = self.create_backup_bundle(src_bundle)
1010         logger.debug(
1011             f'{backup_bundle.uuid}/{backup_bundle.fname}: Scheduling backup for execution...'
1012         )
1013         self._helper_executor.submit(self.launch, backup_bundle)
1014
1015         # Results from backups don't matter; if they finish first
1016         # they will move the result_file to this machine and let
1017         # the original pick them up and unpickle them.
1018
1019     def emergency_retry_nasty_bundle(self, bundle: BundleDetails) -> Any:
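        """Re-launch a bundle that failed unexpectedly, up to a small
        per-bundle retry limit.  Originals that exhaust their retries
        raise RemoteExecutorException; failed backups are just dropped."""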
1020         is_original = bundle.src_bundle is None
1021         bundle.worker = None
1022         avoid_last_machine = bundle.machine
1023         bundle.machine = None
1024         bundle.username = None
1025         bundle.failure_count += 1
1026         if is_original:
1027             retry_limit = 3
1028         else:
1029             retry_limit = 2
1030
1031         if bundle.failure_count > retry_limit:
1032             logger.error(
1033                 f'{bundle}: Tried this bundle too many times already ({retry_limit}x); giving up.'
1034             )
1035             if is_original:
1036                 raise RemoteExecutorException(
1037                     f'{bundle}: This bundle can\'t be completed despite several backups and retries'
1038                 )
1039             else:
1040                 logger.error(
1041                     f'{bundle}: At least it\'s only a backup; better luck with the others.'
1042                 )
1043             return None
1044         else:
1045             msg = f'>>> Emergency rescheduling {bundle} because of unexpected errors (wtf?!) <<<'
1046             logger.warning(msg)
1047             warnings.warn(msg)
1048             return self.launch(bundle, avoid_last_machine)
1049
1050     @overrides
1051     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
1052         pickle = make_cloud_pickle(function, *args, **kwargs)
1053         bundle = self.create_original_bundle(pickle, function.__name__)
1054         self.total_bundles_submitted += 1
1055         return self._helper_executor.submit(self.launch, bundle)
1056
1057     @overrides
1058     def shutdown(self, wait=True) -> None:
1059         logger.debug(f'Shutting down RemoteExecutor {self.title}')
1060         self.heartbeat_stop_event.set()
1061         self.heartbeat_thread.join()
1062         self._helper_executor.shutdown(wait)
1063         print(self.histogram)
1064
1065
1066 @singleton
1067 class DefaultExecutors(object):
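    """A singleton that lazily constructs and caches one executor of each
    flavor.

    A usage sketch (do_work is a hypothetical callable):

        executors = DefaultExecutors()
        future = executors.thread_pool().submit(do_work, 42)
        print(future.result())
        executors.shutdown()

    remote_pool() pings a hard-coded list of helper machines and only
    registers those that respond.
    """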
1068     def __init__(self):
1069         self.thread_executor: Optional[ThreadExecutor] = None
1070         self.process_executor: Optional[ProcessExecutor] = None
1071         self.remote_executor: Optional[RemoteExecutor] = None
1072
1073     def ping(self, host) -> bool:
1074         logger.debug(f'RUN> ping -c 1 {host}')
1075         try:
1076             x = cmd_with_timeout(
1077                 f'ping -c 1 {host} >/dev/null 2>/dev/null', timeout_seconds=1.0
1078             )
1079             return x == 0
1080         except Exception:
1081             return False
1082
1083     def thread_pool(self) -> ThreadExecutor:
1084         if self.thread_executor is None:
1085             self.thread_executor = ThreadExecutor()
1086         return self.thread_executor
1087
1088     def process_pool(self) -> ProcessExecutor:
1089         if self.process_executor is None:
1090             self.process_executor = ProcessExecutor()
1091         return self.process_executor
1092
1093     def remote_pool(self) -> RemoteExecutor:
1094         if self.remote_executor is None:
1095             logger.info('Looking for some helper machines...')
1096             pool: List[RemoteWorkerRecord] = []
1097             if self.ping('cheetah.house'):
1098                 logger.info('Found cheetah.house')
1099                 pool.append(
1100                     RemoteWorkerRecord(
1101                         username='scott',
1102                         machine='cheetah.house',
1103                         weight=25,
1104                         count=6,
1105                     ),
1106                 )
1107             if self.ping('meerkat.cabin'):
1108                 logger.info('Found meerkat.cabin')
1109                 pool.append(
1110                     RemoteWorkerRecord(
1111                         username='scott',
1112                         machine='meerkat.cabin',
1113                         weight=12,
1114                         count=2,
1115                     ),
1116                 )
1117             if self.ping('wannabe.house'):
1118                 logger.info('Found wannabe.house')
1119                 pool.append(
1120                     RemoteWorkerRecord(
1121                         username='scott',
1122                         machine='wannabe.house',
1123                         weight=30,
1124                         count=10,
1125                     ),
1126                 )
1127             if self.ping('puma.cabin'):
1128                 logger.info('Found puma.cabin')
1129                 pool.append(
1130                     RemoteWorkerRecord(
1131                         username='scott',
1132                         machine='puma.cabin',
1133                         weight=25,
1134                         count=6,
1135                     ),
1136                 )
1137             if self.ping('backup.house'):
1138                 logger.info('Found backup.house')
1139                 pool.append(
1140                     RemoteWorkerRecord(
1141                         username='scott',
1142                         machine='backup.house',
1143                         weight=7,
1144                         count=2,
1145                     ),
1146                 )
1147
1148             # The controller machine has a lot to do; go easy on it.
1149             for record in pool:
1150                 if record.machine == platform.node() and record.count > 1:
1151                     logger.info(f'Reducing workload for {record.machine}.')
1152                     record.count = 1
1153
1154             policy = WeightedRandomRemoteWorkerSelectionPolicy()
1155             policy.register_worker_pool(pool)
1156             self.remote_executor = RemoteExecutor(pool, policy)
1157         return self.remote_executor
1158
1159     def shutdown(self) -> None:
1160         if self.thread_executor is not None:
1161             self.thread_executor.shutdown()
1162             self.thread_executor = None
1163         if self.process_executor is not None:
1164             self.process_executor.shutdown()
1165             self.process_executor = None
1166         if self.remote_executor is not None:
1167             self.remote_executor.shutdown()
1168             self.remote_executor = None