Make smart futures avoid polling.
[python_utils.git] / executors.py
1 #!/usr/bin/env python3
2
3 from __future__ import annotations
4
5 from abc import ABC, abstractmethod
6 import concurrent.futures as fut
7 from collections import defaultdict
8 from dataclasses import dataclass
9 import logging
10 import numpy
11 import os
12 import platform
13 import random
14 import subprocess
15 import threading
16 import time
17 from typing import Any, Callable, Dict, List, Optional, Set
18
19 import cloudpickle  # type: ignore
20 from overrides import overrides
21
22 from ansi import bg, fg, underline, reset
23 import argparse_utils
24 import config
25 from decorator_utils import singleton
26 from exec_utils import run_silently, cmd_in_background, cmd_with_timeout
27 import histogram as hist
28
29 logger = logging.getLogger(__name__)
30
31 parser = config.add_commandline_args(
32     f"Executors ({__file__})",
33     "Args related to processing executors."
34 )
35 parser.add_argument(
36     '--executors_threadpool_size',
37     type=int,
38     metavar='#THREADS',
39     help='Number of threads in the default threadpool, leave unset for default',
40     default=None
41 )
42 parser.add_argument(
43     '--executors_processpool_size',
44     type=int,
45     metavar='#PROCESSES',
46     help='Number of processes in the default processpool, leave unset for default',
47     default=None,
48 )
49 parser.add_argument(
50     '--executors_schedule_remote_backups',
51     default=True,
52     action=argparse_utils.ActionNoYes,
53     help='Should we schedule duplicative backup work if a remote bundle is slow',
54 )
55 parser.add_argument(
56     '--executors_max_bundle_failures',
57     type=int,
58     default=3,
59     metavar='#FAILURES',
60     help='Maximum number of failures before giving up on a bundle',
61 )
62
63 RSYNC = 'rsync -q --no-motd -W --ignore-existing --timeout=60 --size-only -z'
64 SSH = 'ssh -oForwardX11=no'
65
66
67 def make_cloud_pickle(fun, *args, **kwargs):
68     logger.debug(f"Making cloudpickled bundle at {fun.__name__}")
69     return cloudpickle.dumps((fun, args, kwargs))
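# A minimal round-trip sketch (illustrative payload, not part of this module):
#
#     payload = make_cloud_pickle(math.sqrt, 16)
#     fun, args, kwargs = cloudpickle.loads(payload)
#     assert fun(*args, **kwargs) == 4.0
#
# ProcessExecutor.run_cloud_pickle() below performs the loads() half of this;
# the remote_worker.py helper invoked over ssh presumably does the same.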
70
71
72 class BaseExecutor(ABC):
73     def __init__(self, *, title=''):
74         self.title = title
75         self.task_count = 0
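        # Track per-task latencies (in seconds); the bucket helper presumably
        # yields 50 evenly spaced buckets spanning 0..500s.  Each shutdown()
        # implementation below prints this histogram.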
76         self.histogram = hist.SimpleHistogram(
77             hist.SimpleHistogram.n_evenly_spaced_buckets(
78                 int(0), int(500), 50
79             )
80         )
81
82     @abstractmethod
83     def submit(self,
84                function: Callable,
85                *args,
86                **kwargs) -> fut.Future:
87         pass
88
89     @abstractmethod
90     def shutdown(self,
91                  wait: bool = True) -> None:
92         pass
93
94     def adjust_task_count(self, delta: int) -> None:
95         self.task_count += delta
96         logger.debug(f'Executor current task count is {self.task_count}')
97
98
99 class ThreadExecutor(BaseExecutor):
100     def __init__(self,
101                  max_workers: Optional[int] = None):
102         super().__init__()
103         workers = None
104         if max_workers is not None:
105             workers = max_workers
106         elif 'executors_threadpool_size' in config.config:
107             workers = config.config['executors_threadpool_size']
108         logger.debug(f'Creating threadpool executor with {workers} workers')
109         self._thread_pool_executor = fut.ThreadPoolExecutor(
110             max_workers=workers,
111             thread_name_prefix="thread_executor_helper"
112         )
113
114     def run_local_bundle(self, fun, *args, **kwargs):
115         logger.debug(f"Running local bundle at {fun.__name__}")
116         start = time.time()
117         result = fun(*args, **kwargs)
118         end = time.time()
119         self.adjust_task_count(-1)
120         duration = end - start
121         logger.debug(f"{fun.__name__} finished; used {duration:.1f}s")
122         self.histogram.add_item(duration)
123         return result
124
125     @overrides
126     def submit(self,
127                function: Callable,
128                *args,
129                **kwargs) -> fut.Future:
130         self.adjust_task_count(+1)
131         newargs = [function, *args]
135         return self._thread_pool_executor.submit(
136             self.run_local_bundle,
137             *newargs,
138             **kwargs)
139
140     @overrides
141     def shutdown(self,
142                  wait = True) -> None:
143         logger.debug(f'Shutting down threadpool executor {self.title}')
144         print(self.histogram)
145         self._thread_pool_executor.shutdown(wait)
146
147
148 class ProcessExecutor(BaseExecutor):
149     def __init__(self,
150                  max_workers=None):
151         super().__init__()
152         workers = None
153         if max_workers is not None:
154             workers = max_workers
155         elif 'executors_processpool_size' in config.config:
156             workers = config.config['executors_processpool_size']
157         logger.debug(f'Creating processpool executor with {workers} workers.')
158         self._process_executor = fut.ProcessPoolExecutor(
159             max_workers=workers,
160         )
161
162     def run_cloud_pickle(self, pickle):
163         fun, args, kwargs = cloudpickle.loads(pickle)
164         logger.debug(f"Running pickled bundle at {fun.__name__}")
165         result = fun(*args, **kwargs)
166         self.adjust_task_count(-1)
167         return result
168
169     @overrides
170     def submit(self,
171                function: Callable,
172                *args,
173                **kwargs) -> fut.Future:
174         start = time.time()
175         self.adjust_task_count(+1)
176         pickle = make_cloud_pickle(function, *args, **kwargs)
177         result = self._process_executor.submit(
178             self.run_cloud_pickle,
179             pickle
180         )
181         result.add_done_callback(
182             lambda _: self.histogram.add_item(
183                 time.time() - start
184             )
185         )
186         return result
187
188     @overrides
189     def shutdown(self, wait=True) -> None:
190         logger.debug(f'Shutting down processpool executor {self.title}')
191         self._process_executor.shutdown(wait)
192         print(self.histogram)
193
194     def __getstate__(self):
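        # Drop the unpicklable pool handle so instances of this executor can
        # themselves be pickled (e.g. handed to a child process) if needed.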
195         state = self.__dict__.copy()
196         state['_process_executor'] = None
197         return state
198
199
200 class RemoteExecutorException(Exception):
201     """Thrown when a bundle cannot be executed despite several retries."""
202     pass
203
204
205 @dataclass
206 class RemoteWorkerRecord:
207     username: str
208     machine: str
209     weight: int
210     count: int
211
212     def __hash__(self):
213         return hash((self.username, self.machine))
214
215     def __repr__(self):
216         return f'{self.username}@{self.machine}'
217
218
219 @dataclass
220 class BundleDetails:
221     pickled_code: bytes
222     uuid: str
223     fname: str
224     worker: Optional[RemoteWorkerRecord]
225     username: Optional[str]
226     machine: Optional[str]
227     hostname: str
228     code_file: str
229     result_file: str
230     pid: int
231     start_ts: float
232     end_ts: float
233     slower_than_local_p95: bool
234     slower_than_global_p95: bool
235     src_bundle: BundleDetails
236     is_cancelled: threading.Event
237     was_cancelled: bool
238     backup_bundles: Optional[List[BundleDetails]]
239     failure_count: int
240
241     def __repr__(self):
242         uuid = self.uuid
243         if uuid[-9:-2] == '_backup':
244             uuid = uuid[:-9]
245             suffix = f'{uuid[-6:]}_b{self.uuid[-1:]}'
246         else:
247             suffix = uuid[-6:]
248
249         colorz = [
250             fg('violet red'),
251             fg('red'),
252             fg('orange'),
253             fg('peach orange'),
254             fg('yellow'),
255             fg('marigold yellow'),
256             fg('green yellow'),
257             fg('tea green'),
258             fg('cornflower blue'),
259             fg('turquoise blue'),
260             fg('tropical blue'),
261             fg('lavender purple'),
262             fg('medium purple'),
263         ]
264         c = colorz[int(uuid[-2:], 16) % len(colorz)]
265         fname = self.fname if self.fname is not None else 'nofname'
266         machine = self.machine if self.machine is not None else 'nomachine'
267         return f'{c}{suffix}/{fname}/{machine}{reset()}'
268
269
270 class RemoteExecutorStatus:
271     def __init__(self, total_worker_count: int) -> None:
272         self.worker_count = total_worker_count
273         self.known_workers: Set[RemoteWorkerRecord] = set()
274         self.start_per_bundle: Dict[str, Optional[float]] = defaultdict(float)
275         self.end_per_bundle: Dict[str, float] = defaultdict(float)
276         self.finished_bundle_timings_per_worker: Dict[
277             RemoteWorkerRecord,
278             List[float]
279         ] = {}
280         self.in_flight_bundles_by_worker: Dict[
281             RemoteWorkerRecord,
282             Set[str]
283         ] = {}
284         self.bundle_details_by_uuid: Dict[str, BundleDetails] = {}
285         self.finished_bundle_timings: List[float] = []
286         self.last_periodic_dump: Optional[float] = None
287         self.total_bundles_submitted = 0
288
289         # Protects reads and modifications of self.  Also used
290         # as a memory fence for modifications to bundles.
291         self.lock = threading.Lock()
292
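    # Convention: each record_X() method below takes self.lock and delegates
    # to a record_X_already_locked() twin that merely asserts the lock is
    # held, so callers already inside a critical section can compose them.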
293     def record_acquire_worker(
294             self,
295             worker: RemoteWorkerRecord,
296             uuid: str
297     ) -> None:
298         with self.lock:
299             self.record_acquire_worker_already_locked(
300                 worker,
301                 uuid
302             )
303
304     def record_acquire_worker_already_locked(
305             self,
306             worker: RemoteWorkerRecord,
307             uuid: str
308     ) -> None:
309         assert self.lock.locked()
310         self.known_workers.add(worker)
311         self.start_per_bundle[uuid] = None  # None => processing hasn't started yet
312         x = self.in_flight_bundles_by_worker.get(worker, set())
313         x.add(uuid)
314         self.in_flight_bundles_by_worker[worker] = x
315
316     def record_bundle_details(
317             self,
318             details: BundleDetails) -> None:
319         with self.lock:
320             self.record_bundle_details_already_locked(details)
321
322     def record_bundle_details_already_locked(
323             self,
324             details: BundleDetails) -> None:
325         assert self.lock.locked()
326         self.bundle_details_by_uuid[details.uuid] = details
327
328     def record_release_worker(
329             self,
330             worker: RemoteWorkerRecord,
331             uuid: str,
332             was_cancelled: bool,
333     ) -> None:
334         with self.lock:
335             self.record_release_worker_already_locked(
336                 worker, uuid, was_cancelled
337             )
338
339     def record_release_worker_already_locked(
340             self,
341             worker: RemoteWorkerRecord,
342             uuid: str,
343             was_cancelled: bool,
344     ) -> None:
345         assert self.lock.locked()
346         ts = time.time()
347         self.end_per_bundle[uuid] = ts
348         self.in_flight_bundles_by_worker[worker].remove(uuid)
349         if not was_cancelled:
350             bundle_latency = ts - self.start_per_bundle[uuid]
351             x = self.finished_bundle_timings_per_worker.get(worker, list())
352             x.append(bundle_latency)
353             self.finished_bundle_timings_per_worker[worker] = x
354             self.finished_bundle_timings.append(bundle_latency)
355
356     def record_processing_began(self, uuid: str):
357         with self.lock:
358             self.start_per_bundle[uuid] = time.time()
359
360     def total_in_flight(self) -> int:
361         assert self.lock.locked()
362         total_in_flight = 0
363         for worker in self.known_workers:
364             total_in_flight += len(self.in_flight_bundles_by_worker[worker])
365         return total_in_flight
366
367     def total_idle(self) -> int:
368         assert self.lock.locked()
369         return self.worker_count - self.total_in_flight()
370
371     def __repr__(self):
372         assert self.lock.locked()
373         ts = time.time()
374         total_finished = len(self.finished_bundle_timings)
375         total_in_flight = self.total_in_flight()
376         ret = f'\n\n{underline()}Remote Executor Pool Status{reset()}: '
377         qall = None
378         if len(self.finished_bundle_timings) > 1:
379             qall = numpy.quantile(self.finished_bundle_timings, [0.5, 0.95])
380             ret += (
381                 f'⏱=∀p50:{qall[0]:.1f}s, ∀p95:{qall[1]:.1f}s, '
382                 f'✅={total_finished}/{self.total_bundles_submitted}, '
383                 f'💻n={total_in_flight}/{self.worker_count}\n'
384             )
385         else:
386             ret += (
387                 f' ✅={total_finished}/{self.total_bundles_submitted}, '
388                 f'💻n={total_in_flight}/{self.worker_count}\n'
389             )
390
391         for worker in self.known_workers:
392             ret += f'  {fg("lightning yellow")}{worker.machine}{reset()}: '
393             timings = self.finished_bundle_timings_per_worker.get(worker, [])
394             count = len(timings)
395             qworker = None
396             if count > 1:
397                 qworker = numpy.quantile(timings, [0.5, 0.95])
398                 ret += f' 💻p50: {qworker[0]:.1f}s, 💻p95: {qworker[1]:.1f}s\n'
399             else:
400                 ret += '\n'
401             if count > 0:
402                 ret += f'    ...finished {count} total bundle(s) so far\n'
403             in_flight = len(self.in_flight_bundles_by_worker[worker])
404             if in_flight > 0:
405                 ret += f'    ...{in_flight} bundles currently in flight:\n'
406                 for bundle_uuid in self.in_flight_bundles_by_worker[worker]:
407                     details = self.bundle_details_by_uuid.get(
408                         bundle_uuid,
409                         None
410                     )
411                     pid = str(details.pid) if (details and details.pid != 0) else "TBD"
412                     if self.start_per_bundle[bundle_uuid] is not None:
413                         sec = ts - self.start_per_bundle[bundle_uuid]
414                         ret += f'       (pid={pid}): {details} for {sec:.1f}s so far '
415                     else:
416                         ret += f'       {details} setting up / copying data...'
417                         sec = 0.0
418
419                     if qworker is not None:
420                         if sec > qworker[1]:
421                             ret += f'{bg("red")}>💻p95{reset()} '
422                             if details is not None:
423                                 details.slower_than_local_p95 = True
424                         else:
425                             if details is not None:
426                                 details.slower_than_local_p95 = False
427
428                     if qall is not None:
429                         if sec > qall[1]:
430                             ret += f'{bg("red")}>∀p95{reset()} '
431                             if details is not None:
432                                 details.slower_than_global_p95 = True
433                         elif details is not None:
434                             details.slower_than_global_p95 = False
435                     ret += '\n'
436         return ret
437
438     def periodic_dump(self, total_bundles_submitted: int) -> None:
439         assert self.lock.locked()
440         self.total_bundles_submitted = total_bundles_submitted
441         ts = time.time()
442         if (
443                 self.last_periodic_dump is None
444                 or ts - self.last_periodic_dump > 5.0
445         ):
446             print(self)
447             self.last_periodic_dump = ts
448
449
450 class RemoteWorkerSelectionPolicy(ABC):
451     def register_worker_pool(self, workers):
452         self.workers = workers
453
454     @abstractmethod
455     def is_worker_available(self) -> bool:
456         pass
457
458     @abstractmethod
459     def acquire_worker(
460             self,
461             machine_to_avoid = None
462     ) -> Optional[RemoteWorkerRecord]:
463         pass
464
465
466 class WeightedRandomRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
467     @overrides
468     def is_worker_available(self) -> bool:
469         for worker in self.workers:
470             if worker.count > 0:
471                 return True
472         return False
473
474     @overrides
475     def acquire_worker(
476             self,
477             machine_to_avoid = None
478     ) -> Optional[RemoteWorkerRecord]:
479         grabbag = []
480         for worker in self.workers:
481             for x in range(0, worker.count):
482                 for y in range(0, worker.weight):
483                     grabbag.append(worker)
484
485         for attempt in range(0, 5):
486             random.shuffle(grabbag)
487             worker = grabbag[0]
488             if worker.machine != machine_to_avoid or attempt > 2:
489                 if worker.count > 0:
490                     worker.count -= 1
491                     logger.debug(f'Selected worker {worker}')
492                     return worker
493         logger.warning("Couldn't find a worker; go fish.")
494         return None
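# A hedged wiring sketch (hostname and weights are illustrative; RemoteExecutor
# and DefaultExecutors below do this for real):
#
#     pool = [RemoteWorkerRecord(username='scott', machine='cheetah.house',
#                                weight=14, count=4)]
#     policy = WeightedRandomRemoteWorkerSelectionPolicy()
#     policy.register_worker_pool(pool)
#     if policy.is_worker_available():
#         worker = policy.acquire_worker()   # decrements worker.count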
495
496
497 class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
498     def __init__(self) -> None:
499         self.index = 0
500
501     @overrides
502     def is_worker_available(self) -> bool:
503         for worker in self.workers:
504             if worker.count > 0:
505                 return True
506         return False
507
508     @overrides
509     def acquire_worker(
510             self,
511             machine_to_avoid: Optional[str] = None
512     ) -> Optional[RemoteWorkerRecord]:
513         x = self.index
514         while True:
515             worker = self.workers[x]
516             if worker.count > 0:
517                 worker.count -= 1
518                 x += 1
519                 if x >= len(self.workers):
520                     x = 0
521                 self.index = x
522                 logger.debug(f'Selected worker {worker}')
523                 return worker
524             x += 1
525             if x >= len(self.workers):
526                 x = 0
527             if x == self.index:
528                 logger.warning("Couldn't find a worker; go fish.")
529                 return None
530
531
532 class RemoteExecutor(BaseExecutor):
533     def __init__(self,
534                  workers: List[RemoteWorkerRecord],
535                  policy: RemoteWorkerSelectionPolicy) -> None:
536         super().__init__()
537         self.workers = workers
538         self.policy = policy
539         self.worker_count = 0
540         for worker in self.workers:
541             self.worker_count += worker.count
542         if self.worker_count <= 0:
543             msg = f"We need somewhere to schedule work; count was {self.worker_count}"
544             logger.critical(msg)
545             raise RemoteExecutorException(msg)
546         self.policy.register_worker_pool(self.workers)
547         self.cv = threading.Condition()
548         logger.debug(f'Creating {self.worker_count} local threads, one per remote worker.')
549         self._helper_executor = fut.ThreadPoolExecutor(
550             thread_name_prefix="remote_executor_helper",
551             max_workers=self.worker_count,
552         )
553         self.status = RemoteExecutorStatus(self.worker_count)
554         self.total_bundles_submitted = 0
555         self.backup_lock = threading.Lock()
556         self.last_backup = None
557
558     def is_worker_available(self) -> bool:
559         return self.policy.is_worker_available()
560
561     def acquire_worker(
562             self,
563             machine_to_avoid: Optional[str] = None
564     ) -> Optional[RemoteWorkerRecord]:
565         return self.policy.acquire_worker(machine_to_avoid)
566
567     def find_available_worker_or_block(
568             self,
569             machine_to_avoid: Optional[str] = None
570     ) -> RemoteWorkerRecord:
571         with self.cv:
572             while not self.is_worker_available():
573                 self.cv.wait()
574             worker = self.acquire_worker(machine_to_avoid)
575             if worker is not None:
576                 return worker
577         msg = "We should never reach this point in the code"
578         logger.critical(msg)
579         raise Exception(msg)
580
581     def release_worker(self, worker: RemoteWorkerRecord) -> None:
582         logger.debug(f'Released worker {worker}')
583         with self.cv:
584             worker.count += 1
585             self.cv.notify()
586
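    # Worked example of the backup-scoring heuristic in heartbeat() (numbers
    # are illustrative): a bundle on a weight-14 worker gets
    # base_score = int(200 * (1 / 14)) = 14.  After running for 30s its score
    # is 14 + 30 = 44; if it is also slower than both the worker p95 and the
    # global p95 it gains runtime/2 twice (44 + 15 + 15 = 74); with no
    # existing backups that is doubled to 148.  The highest-scoring bundle
    # wins the backup.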
587     def heartbeat(self) -> None:
588         with self.status.lock:
589             # Regular progress report
590             self.status.periodic_dump(self.total_bundles_submitted)
591
592             # Look for bundles to reschedule.
593             num_done = len(self.status.finished_bundle_timings)
594             num_idle_workers = self.worker_count - self.task_count
595             now = time.time()
596             if (
597                     config.config['executors_schedule_remote_backups']
598                     and num_done > 2
599                     and num_idle_workers > 1
600                     and (self.last_backup is None or (now - self.last_backup > 1.0))
601                     and self.backup_lock.acquire(blocking=False)
602             ):
603                 try:
604                     assert self.backup_lock.locked()
605
606                     bundle_to_backup = None
607                     best_score = None
608                     for worker, bundle_uuids in self.status.in_flight_bundles_by_worker.items():
609                         # Prefer to schedule backups of bundles on slower machines.
610                         base_score = 0
611                         for record in self.workers:
612                             if worker.machine == record.machine:
613                                 base_score = float(record.weight)
614                                 base_score = 1.0 / base_score
615                                 base_score *= 200.0
616                                 base_score = int(base_score)
617                                 break
618
619                         for uuid in bundle_uuids:
620                             bundle = self.status.bundle_details_by_uuid.get(uuid, None)
621                             if (
622                                     bundle is not None
623                                     and bundle.src_bundle is None
624                                     and bundle.backup_bundles is not None
625                             ):
626                                 score = base_score
627
628                                 # Schedule backups of bundles running longer; especially those
629                                 # that are unexpectedly slow.
630                                 start_ts = self.status.start_per_bundle[uuid]
631                                 if start_ts is not None:
632                                     runtime = now - start_ts
633                                     score += runtime
634                                     logger.debug(f'score[{bundle}] => {score}  # latency boost')
635
636                                     if bundle.slower_than_local_p95:
637                                         score += runtime / 2
638                                         logger.debug(f'score[{bundle}] => {score}  # >worker p95')
639
640                                     if bundle.slower_than_global_p95:
641                                         score += runtime / 2
642                                         logger.debug(f'score[{bundle}] => {score}  # >global p95')
643
644                                 # Prefer backups of bundles that don't have backups already.
645                                 backup_count = len(bundle.backup_bundles)
646                                 if backup_count == 0:
647                                     score *= 2
648                                 elif backup_count == 1:
649                                     score /= 2
650                                 elif backup_count == 2:
651                                     score /= 8
652                                 else:
653                                     score = 0
654                                 logger.debug(f'score[{bundle}] => {score}  # {backup_count} dup backup factor')
655
656                                 if (
657                                         score != 0
658                                         and (best_score is None or score > best_score)
659                                 ):
660                                     bundle_to_backup = bundle
661                                     assert bundle is not None
662                                     assert bundle.backup_bundles is not None
663                                     assert bundle.src_bundle is None
664                                     best_score = score
665
666                     if bundle_to_backup is not None:
667                         self.last_backup = now
668                         logger.info(f'=====> SCHEDULING BACKUP {bundle_to_backup} (score={best_score:.1f}) <=====')
669                         self.schedule_backup_for_bundle(bundle_to_backup)
670                 finally:
671                     self.backup_lock.release()
672
673     def check_if_cancelled(self, bundle: BundleDetails) -> bool:
674         with self.status.lock:
675             if bundle.is_cancelled.wait(timeout=0.0):
676                 logger.debug(f'Bundle {bundle.uuid} is cancelled, bail out.')
677                 bundle.was_cancelled = True
678                 return True
679         return False
680
681     def launch(self, bundle: BundleDetails, override_avoid_machine=None) -> Any:
682         """Find a worker for bundle or block until one is available."""
683         self.adjust_task_count(+1)
684         uuid = bundle.uuid
685         hostname = bundle.hostname
686         avoid_machine = override_avoid_machine
687         is_original = bundle.src_bundle is None
688
689         # Try not to schedule a backup on the same host as the original.
690         if avoid_machine is None and bundle.src_bundle is not None:
691             avoid_machine = bundle.src_bundle.machine
692         worker = None
693         while worker is None:
694             worker = self.find_available_worker_or_block(avoid_machine)
695
696         # Ok, found a worker.
697         bundle.worker = worker
698         machine = bundle.machine = worker.machine
699         username = bundle.username = worker.username
700
701         self.status.record_acquire_worker(worker, uuid)
702         logger.debug(f'{bundle}: Running bundle on {worker}...')
703
704         # Before we do any work, make sure the bundle is still viable.
705         if self.check_if_cancelled(bundle):
706             try:
707                 return self.post_launch_work(bundle)
708             except Exception as e:
709                 logger.exception(e)
710                 logger.error(
711                     f'{bundle}: bundle says it\'s cancelled upfront but no results?!'
712                 )
713                 assert bundle.worker is not None
714                 self.status.record_release_worker(
715                     bundle.worker,
716                     bundle.uuid,
717                     True,
718                 )
719                 self.release_worker(bundle.worker)
720                 self.adjust_task_count(-1)
721                 if is_original:
722                     # Weird.  We are the original owner of this
723                     # bundle.  For it to have been cancelled, a backup
724                     # must have already started and completed before
725                     # we even got started.  Moreover, the backup says
726                     # it is done but we can't find the results it
727                     # should have copied over.  Reschedule the whole
728                     # thing.
729                     return self.emergency_retry_nasty_bundle(bundle)
730                 else:
731                     # Expected(?).  We're a backup and our bundle is
732                     # cancelled before we even got started.  Something
733                     # went bad in post_launch_work (I actually don't
734                     # see what?) but probably not worth worrying
735                     # about.
736                     return None
737
738         # Send input code / data to worker machine if it's not local.
739         if hostname not in machine:
740             try:
741                 cmd = f'{RSYNC} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
742                 start_ts = time.time()
743                 logger.info(f"{bundle}: Copying work to {worker} via {cmd}.")
744                 run_silently(cmd)
745                 xfer_latency = time.time() - start_ts
746                 logger.info(f"{bundle}: Copying done to {worker} in {xfer_latency:.1f}s.")
747             except Exception as e:
748                 logger.exception(e)
749                 logger.error(
750                     f'{bundle}: failed to send instructions to worker machine?!?'
751                 )
752                 assert bundle.worker is not None
753                 self.status.record_release_worker(
754                     bundle.worker,
755                     bundle.uuid,
756                     True,
757                 )
758                 self.release_worker(bundle.worker)
759                 self.adjust_task_count(-1)
760                 if is_original:
761                     # Weird.  We tried to copy the code to the worker and it failed...
762                     # And we're the original bundle.  We have to retry.
763                     return self.emergency_retry_nasty_bundle(bundle)
764                 else:
765                     # This is actually expected; we're a backup.
766                     # There's a race condition where someone else
767                     # already finished the work and removed the source
768                     # code file before we could copy it.  No biggie.
769                     return None
770
771         # Kick off the work.  Note that if this fails we let
772         # wait_for_process deal with it.
773         self.status.record_processing_began(uuid)
774         cmd = (f'{SSH} {bundle.username}@{bundle.machine} '
775                f'"source py39-venv/bin/activate &&'
776                f' /home/scott/lib/python_modules/remote_worker.py'
777                f' --code_file {bundle.code_file} --result_file {bundle.result_file}"')
778         logger.debug(f'{bundle}: Executing {cmd} in the background to kick off work...')
779         p = cmd_in_background(cmd, silent=True)
780         bundle.pid = pid = p.pid
781         logger.debug(f'{bundle}: Local ssh process pid={pid}; remote worker is {machine}.')
782         return self.wait_for_process(p, bundle, 0)
783
784     def wait_for_process(self, p: subprocess.Popen, bundle: BundleDetails, depth: int) -> Any:
785         machine = bundle.machine
786         pid = p.pid
787         if depth > 3:
788             logger.error(
789                 f"I've gotten repeated errors waiting on this bundle; giving up on pid={pid}."
790             )
791             p.terminate()
792             self.status.record_release_worker(
793                 bundle.worker,
794                 bundle.uuid,
795                 True,
796             )
797             self.release_worker(bundle.worker)
798             self.adjust_task_count(-1)
799             return self.emergency_retry_nasty_bundle(bundle)
800
801         # Spin until either the ssh job we scheduled finishes the
802         # bundle or some backup worker signals that they finished it
803         # before we could.
804         while True:
805             try:
806                 p.wait(timeout=0.25)
807             except subprocess.TimeoutExpired:
808                 self.heartbeat()
809                 if self.check_if_cancelled(bundle):
810                     logger.info(
811                         f'{bundle}: another worker finished bundle, checking it out...'
812                     )
813                     break
814             else:
815                 logger.info(
816                     f"{bundle}: pid {pid} ({machine}) our ssh finished, checking it out..."
817                 )
818                 p = None
819                 break
820
821         # If we get here we believe the bundle is done; either the ssh
822         # subprocess finished (hopefully successfully) or we noticed
823         # that some other worker seems to have completed the bundle
824         # and we're bailing out.
825         try:
826             ret = self.post_launch_work(bundle)
827             if ret is not None and p is not None:
828                 p.terminate()
829             return ret
830
831         # Something went wrong; e.g. we could not copy the results
832         # back, clean up after ourselves on the remote machine, or
833         # unpickle the results we got from the remote machine.  If we
834         # still have an active ssh subprocess, keep waiting on it.
835         # Otherwise, time for an emergency reschedule.
836         except Exception as e:
837             logger.exception(e)
838             logger.error(f'{bundle}: Something unexpected just happened...')
839             if p is not None:
840                 logger.warning(
841                     f"{bundle}: Failed to wrap up \"done\" bundle, re-waiting on active ssh."
842                 )
843                 return self.wait_for_process(p, bundle, depth + 1)
844             else:
845                 self.status.record_release_worker(
846                     bundle.worker,
847                     bundle.uuid,
848                     True,
849                 )
850                 self.release_worker(bundle.worker)
851                 self.adjust_task_count(-1)
852                 return self.emergency_retry_nasty_bundle(bundle)
853
854     def post_launch_work(self, bundle: BundleDetails) -> Any:
855         with self.status.lock:
856             is_original = bundle.src_bundle is None
857             was_cancelled = bundle.was_cancelled
858             username = bundle.username
859             machine = bundle.machine
860             result_file = bundle.result_file
861             code_file = bundle.code_file
862
863             # Whether original or backup, if we finished first we must
864             # fetch the results if the computation happened on a
865             # remote machine.
866             bundle.end_ts = time.time()
867             if not was_cancelled:
868                 assert bundle.machine is not None
869                 if bundle.hostname not in bundle.machine:
870                     cmd = f'{RSYNC} {username}@{machine}:{result_file} {result_file} 2>/dev/null'
871                     logger.info(
872                         f"{bundle}: Fetching results from {username}@{machine} via {cmd}"
873                     )
874
875                     # If either of these throw they are handled in
876                     # wait_for_process.
877                     run_silently(cmd)
878                     run_silently(f'{SSH} {username}@{machine}'
879                                  f' "/bin/rm -f {code_file} {result_file}"')
880                 dur = bundle.end_ts - bundle.start_ts
881                 self.histogram.add_item(dur)
882
883         # Only the original worker should unpickle the file contents
884         # though since it's the only one whose result matters.  The
885         # original is also the only job that may delete result_file
886         # from disk.  Note that the original may have been cancelled
887         # if one of the backups finished first; it still must read the
888         # result from disk.
889         if is_original:
890             logger.debug(f"{bundle}: Unpickling {result_file}.")
891             try:
892                 with open(f'{result_file}', 'rb') as rb:
893                     serialized = rb.read()
894                 result = cloudpickle.loads(serialized)
895             except Exception as e:
896                 msg = f'Failed to load {result_file}, this is bad news.'
897                 logger.critical(msg)
898                 self.status.record_release_worker(
899                     bundle.worker,
900                     bundle.uuid,
901                     True,
902                 )
903                 self.release_worker(bundle.worker)
904
905                 # Re-raise the exception; the code in wait_for_process may
906                 # decide to emergency_retry_nasty_bundle here.
907                 raise Exception(e)
908
909             logger.debug(
910                 f'Removing local (master) {code_file} and {result_file}.'
911             )
912             os.remove(f'{result_file}')
913             os.remove(f'{code_file}')
914
915             # Notify any backups that the original is done so they
916             # should stop ASAP.  Do this whether or not we
917             # finished first since there could be more than one
918             # backup.
919             if bundle.backup_bundles is not None:
920                 for backup in bundle.backup_bundles:
921                     logger.debug(
922                         f'{bundle}: Notifying backup {backup.uuid} that it\'s cancelled'
923                     )
924                     backup.is_cancelled.set()
925
926         # This is a backup job and, by now, we have already fetched
927         # the bundle results.
928         else:
929             # Backup results don't matter, they just need to leave the
930             # result file in the right place for their originals to
931             # read/unpickle later.
932             result = None
933
934             # Tell the original to stop if we finished first.
935             if not was_cancelled:
936                 logger.debug(
937                     f'{bundle}: Notifying original {bundle.src_bundle.uuid} we beat them to it.'
938                 )
939                 bundle.src_bundle.is_cancelled.set()
940
941         assert bundle.worker is not None
942         self.status.record_release_worker(
943             bundle.worker,
944             bundle.uuid,
945             was_cancelled,
946         )
947         self.release_worker(bundle.worker)
948         self.adjust_task_count(-1)
949         return result
950
951     def create_original_bundle(self, pickle, fname: str):
952         from string_utils import generate_uuid
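        # (Imported lazily, presumably to dodge a circular import at load time.)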
953         uuid = generate_uuid(as_hex=True)
954         code_file = f'/tmp/{uuid}.code.bin'
955         result_file = f'/tmp/{uuid}.result.bin'
956
957         logger.debug(f'Writing pickled code to {code_file}')
958         with open(f'{code_file}', 'wb') as wb:
959             wb.write(pickle)
960
961         bundle = BundleDetails(
962             pickled_code = pickle,
963             uuid = uuid,
964             fname = fname,
965             worker = None,
966             username = None,
967             machine = None,
968             hostname = platform.node(),
969             code_file = code_file,
970             result_file = result_file,
971             pid = 0,
972             start_ts = time.time(),
973             end_ts = 0.0,
974             slower_than_local_p95 = False,
975             slower_than_global_p95 = False,
976             src_bundle = None,
977             is_cancelled = threading.Event(),
978             was_cancelled = False,
979             backup_bundles = [],
980             failure_count = 0,
981         )
982         self.status.record_bundle_details(bundle)
983         logger.debug(f'{bundle}: Created an original bundle')
984         return bundle
985
986     def create_backup_bundle(self, src_bundle: BundleDetails):
987         assert src_bundle.backup_bundles is not None
988         n = len(src_bundle.backup_bundles)
989         uuid = src_bundle.uuid + f'_backup#{n}'
990
991         backup_bundle = BundleDetails(
992             pickled_code = src_bundle.pickled_code,
993             uuid = uuid,
994             fname = src_bundle.fname,
995             worker = None,
996             username = None,
997             machine = None,
998             hostname = src_bundle.hostname,
999             code_file = src_bundle.code_file,
1000             result_file = src_bundle.result_file,
1001             pid = 0,
1002             start_ts = time.time(),
1003             end_ts = 0.0,
1004             slower_than_local_p95 = False,
1005             slower_than_global_p95 = False,
1006             src_bundle = src_bundle,
1007             is_cancelled = threading.Event(),
1008             was_cancelled = False,
1009             backup_bundles = None,    # backup backups not allowed
1010             failure_count = 0,
1011         )
1012         src_bundle.backup_bundles.append(backup_bundle)
1013         self.status.record_bundle_details_already_locked(backup_bundle)
1014         logger.debug(f'{backup_bundle}: Created a backup bundle')
1015         return backup_bundle
1016
1017     def schedule_backup_for_bundle(self,
1018                                    src_bundle: BundleDetails):
1019         assert self.status.lock.locked()
1020         assert src_bundle is not None
1021         backup_bundle = self.create_backup_bundle(src_bundle)
1022         logger.debug(
1023             f'{backup_bundle.uuid}/{backup_bundle.fname}: Scheduling backup for execution...'
1024         )
1025         self._helper_executor.submit(self.launch, backup_bundle)
1026
1027         # Results from backups don't matter; if they finish first
1028         # they will move the result_file to this machine and let
1029         # the original pick them up and unpickle them.
1030
1031     def emergency_retry_nasty_bundle(self, bundle: BundleDetails) -> fut.Future:
1032         is_original = bundle.src_bundle is None
1033         bundle.worker = None
1034         avoid_last_machine = bundle.machine
1035         bundle.machine = None
1036         bundle.username = None
1037         bundle.failure_count += 1
1038         if is_original:
1039             retry_limit = 3
1040         else:
1041             retry_limit = 2
1042
1043         if bundle.failure_count > retry_limit:
1044             logger.error(
1045                 f'{bundle}: Tried this bundle too many times already ({retry_limit}x); giving up.'
1046             )
1047             if is_original:
1048                 raise RemoteExecutorException(
1049                     f'{bundle}: This bundle can\'t be completed despite several backups and retries'
1050                 )
1051             else:
1052                 logger.error(f'{bundle}: At least it\'s only a backup; better luck with the others.')
1053             return None
1054         else:
1055             logger.warning(
1056                 f'>>> Emergency rescheduling {bundle} because of unexpected errors (wtf?!) <<<'
1057             )
1058             return self.launch(bundle, avoid_last_machine)
1059
1060     @overrides
1061     def submit(self,
1062                function: Callable,
1063                *args,
1064                **kwargs) -> fut.Future:
1065         pickle = make_cloud_pickle(function, *args, **kwargs)
1066         bundle = self.create_original_bundle(pickle, function.__name__)
1067         self.total_bundles_submitted += 1
1068         return self._helper_executor.submit(self.launch, bundle)
1069
1070     @overrides
1071     def shutdown(self, wait=True) -> None:
1072         self._helper_executor.shutdown(wait)
1073         logger.debug(f'Shutting down RemoteExecutor {self.title}')
1074         print(self.histogram)
1075
1076
1077 @singleton
1078 class DefaultExecutors(object):
1079     def __init__(self):
1080         self.thread_executor: Optional[ThreadExecutor] = None
1081         self.process_executor: Optional[ProcessExecutor] = None
1082         self.remote_executor: Optional[RemoteExecutor] = None
1083
1084     def ping(self, host) -> bool:
1085         logger.debug(f'RUN> ping -c 1 {host}')
1086         try:
1087             x = cmd_with_timeout(
1088                 f'ping -c 1 {host} >/dev/null 2>/dev/null',
1089                 timeout_seconds=1.0
1090             )
1091             return x == 0
1092         except Exception:
1093             return False
1094
1095     def thread_pool(self) -> ThreadExecutor:
1096         if self.thread_executor is None:
1097             self.thread_executor = ThreadExecutor()
1098         return self.thread_executor
1099
1100     def process_pool(self) -> ProcessExecutor:
1101         if self.process_executor is None:
1102             self.process_executor = ProcessExecutor()
1103         return self.process_executor
1104
1105     def remote_pool(self) -> RemoteExecutor:
1106         if self.remote_executor is None:
1107             logger.info('Looking for some helper machines...')
1108             pool: List[RemoteWorkerRecord] = []
1109             if self.ping('cheetah.house'):
1110                 logger.info('Found cheetah.house')
1111                 pool.append(
1112                     RemoteWorkerRecord(
1113                         username = 'scott',
1114                         machine = 'cheetah.house',
1115                         weight = 14,
1116                         count = 4,
1117                     ),
1118                 )
1119             if self.ping('video.house'):
1120                 logger.info('Found video.house')
1121                 pool.append(
1122                     RemoteWorkerRecord(
1123                         username = 'scott',
1124                         machine = 'video.house',
1125                         weight = 1,
1126                         count = 4,
1127                     ),
1128                 )
1129             if self.ping('wannabe.house'):
1130                 logger.info('Found wannabe.house')
1131                 pool.append(
1132                     RemoteWorkerRecord(
1133                         username = 'scott',
1134                         machine = 'wannabe.house',
1135                         weight = 2,
1136                         count = 4,
1137                     ),
1138                 )
1139             if self.ping('meerkat.cabin'):
1140                 logger.info('Found meerkat.cabin')
1141                 pool.append(
1142                     RemoteWorkerRecord(
1143                         username = 'scott',
1144                         machine = 'meerkat.cabin',
1145                         weight = 5,
1146                         count = 2,
1147                     ),
1148                 )
1149             if self.ping('kiosk.house'):
1150                 logger.info('Found kiosk.house')
1151                 pool.append(
1152                     RemoteWorkerRecord(
1153                         username = 'pi',
1154                         machine = 'kiosk.house',
1155                         weight = 1,
1156                         count = 2,
1157                     ),
1158                 )
1159             if self.ping('puma.cabin'):
1160                 logger.info('Found puma.cabin')
1161                 pool.append(
1162                     RemoteWorkerRecord(
1163                         username = 'scott',
1164                         machine = 'puma.cabin',
1165                         weight = 12,
1166                         count = 4,
1167                     ),
1168                 )
1169
1170             # The controller machine has a lot to do; go easy on it.
1171             for record in pool:
1172                 if record.machine == platform.node() and record.count > 1:
1173                     logger.info(f'Reducing workload for {record.machine}.')
1174                     record.count = 1
1175
1176             policy = WeightedRandomRemoteWorkerSelectionPolicy()
1177             policy.register_worker_pool(pool)
1178             self.remote_executor = RemoteExecutor(pool, policy)
1179         return self.remote_executor
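

# A minimal end-to-end usage sketch (assumes the config module's command-line
# arguments have already been parsed; `work` is a hypothetical payload):
#
#     def work(x: int, y: int) -> int:
#         return x + y
#
#     pool = DefaultExecutors().thread_pool()
#     future = pool.submit(work, 1, 2)
#     print(future.result())    # => 3
#     pool.shutdown()
#
# process_pool() and remote_pool() expose the same submit()/shutdown() surface
# via BaseExecutor.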