Clean up the remote executor code and create a dedicated heartbeat thread
[python_utils.git] / executors.py
1 #!/usr/bin/env python3
2
3 from __future__ import annotations
4
5 from abc import ABC, abstractmethod
6 import concurrent.futures as fut
7 from collections import defaultdict
8 from dataclasses import dataclass
9 import logging
10 import numpy
11 import os
12 import platform
13 import random
14 import subprocess
15 import threading
16 import time
17 from typing import Any, Callable, Dict, List, Optional, Set
18 import warnings
19
20 import cloudpickle  # type: ignore
21 from overrides import overrides
22
23 from ansi import bg, fg, underline, reset
24 import argparse_utils
25 import config
26 from decorator_utils import singleton
27 from exec_utils import run_silently, cmd_in_background, cmd_with_timeout
28 import histogram as hist
29 from thread_utils import background_thread
30
31
32 logger = logging.getLogger(__name__)
33
34 parser = config.add_commandline_args(
35     f"Executors ({__file__})",
36     "Args related to processing executors."
37 )
38 parser.add_argument(
39     '--executors_threadpool_size',
40     type=int,
41     metavar='#THREADS',
42     help='Number of threads in the default threadpool, leave unset for default',
43     default=None
44 )
45 parser.add_argument(
46     '--executors_processpool_size',
47     type=int,
48     metavar='#PROCESSES',
49     help='Number of processes in the default processpool, leave unset for default',
50     default=None,
51 )
52 parser.add_argument(
53     '--executors_schedule_remote_backups',
54     default=True,
55     action=argparse_utils.ActionNoYes,
56     help='Should we schedule duplicative backup work if a remote bundle is slow',
57 )
58 parser.add_argument(
59     '--executors_max_bundle_failures',
60     type=int,
61     default=3,
62     metavar='#FAILURES',
63     help='Maximum number of failures before giving up on a bundle',
64 )
65
66 SSH = '/usr/bin/ssh -oForwardX11=no'
67 SCP = '/usr/bin/scp'
68
69
70 def make_cloud_pickle(fun, *args, **kwargs):
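    """Serialize fun and its arguments into a bytes bundle via cloudpickle."""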
71     logger.debug(f"Making cloudpickled bundle for {fun.__name__}")
72     return cloudpickle.dumps((fun, args, kwargs))
73
74
75 class BaseExecutor(ABC):
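    """Superclass of all executor flavors.  Defines the submit/shutdown
    interface and tracks a task count plus a latency histogram."""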
76     def __init__(self, *, title=''):
77         self.title = title
78         self.task_count = 0
79         self.histogram = hist.SimpleHistogram(
80             hist.SimpleHistogram.n_evenly_spaced_buckets(
81                 int(0), int(500), 50
82             )
83         )
84
85     @abstractmethod
86     def submit(self,
87                function: Callable,
88                *args,
89                **kwargs) -> fut.Future:
90         pass
91
92     @abstractmethod
93     def shutdown(self,
94                  wait: bool = True) -> None:
95         pass
96
97     def adjust_task_count(self, delta: int) -> None:
98         self.task_count += delta
99         logger.debug(f'Executor current task count is {self.task_count}')
100
101
102 class ThreadExecutor(BaseExecutor):
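    """An executor that runs work on a local thread pool (a thin
    wrapper around concurrent.futures.ThreadPoolExecutor)."""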
103     def __init__(self,
104                  max_workers: Optional[int] = None):
105         super().__init__()
106         workers = None
107         if max_workers is not None:
108             workers = max_workers
109         elif 'executors_threadpool_size' in config.config:
110             workers = config.config['executors_threadpool_size']
111         logger.debug(f'Creating threadpool executor with {workers} workers')
112         self._thread_pool_executor = fut.ThreadPoolExecutor(
113             max_workers=workers,
114             thread_name_prefix="thread_executor_helper"
115         )
116
117     def run_local_bundle(self, fun, *args, **kwargs):
118         logger.debug(f"Running local bundle for {fun.__name__}")
119         start = time.time()
120         result = fun(*args, **kwargs)
121         end = time.time()
122         self.adjust_task_count(-1)
123         duration = end - start
124         logger.debug(f"{fun.__name__} finished; used {duration:.1f}s")
125         self.histogram.add_item(duration)
126         return result
127
128     @overrides
129     def submit(self,
130                function: Callable,
131                *args,
132                **kwargs) -> fut.Future:
133         self.adjust_task_count(+1)
134         newargs = []
135         newargs.append(function)
136         for arg in args:
137             newargs.append(arg)
138         return self._thread_pool_executor.submit(
139             self.run_local_bundle,
140             *newargs,
141             **kwargs)
142
143     @overrides
144     def shutdown(self,
145                  wait: bool = True) -> None:
146         logger.debug(f'Shutting down threadpool executor {self.title}')
147         print(self.histogram)
148         self._thread_pool_executor.shutdown(wait)
149
150
151 class ProcessExecutor(BaseExecutor):
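    """An executor that runs cloudpickled work on a local process pool
    (a thin wrapper around concurrent.futures.ProcessPoolExecutor)."""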
152     def __init__(self,
153                  max_workers=None):
154         super().__init__()
155         workers = None
156         if max_workers is not None:
157             workers = max_workers
158         elif 'executors_processpool_size' in config.config:
159             workers = config.config['executors_processpool_size']
160         logger.debug(f'Creating processpool executor with {workers} workers.')
161         self._process_executor = fut.ProcessPoolExecutor(
162             max_workers=workers,
163         )
164
165     def run_cloud_pickle(self, pickle):
166         fun, args, kwargs = cloudpickle.loads(pickle)
167         logger.debug(f"Running pickled bundle for {fun.__name__}")
168         result = fun(*args, **kwargs)
169         self.adjust_task_count(-1)
170         return result
171
172     @overrides
173     def submit(self,
174                function: Callable,
175                *args,
176                **kwargs) -> fut.Future:
177         start = time.time()
178         self.adjust_task_count(+1)
179         pickle = make_cloud_pickle(function, *args, **kwargs)
180         result = self._process_executor.submit(
181             self.run_cloud_pickle,
182             pickle
183         )
184         result.add_done_callback(
185             lambda _: self.histogram.add_item(
186                 time.time() - start
187             )
188         )
189         return result
190
191     @overrides
192     def shutdown(self, wait=True) -> None:
193         logger.debug(f'Shutting down processpool executor {self.title}')
194         self._process_executor.shutdown(wait)
195         print(self.histogram)
196
197     def __getstate__(self):
198         state = self.__dict__.copy()
199         state['_process_executor'] = None
200         return state
201
202
203 class RemoteExecutorException(Exception):
204     """Raised when a bundle cannot be executed despite several retries."""
205     pass
206
207
208 @dataclass
209 class RemoteWorkerRecord:
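    """A record describing a remote worker: the account and machine to
    ssh into, its relative weight, and how many concurrent tasks it may
    run."""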
210     username: str
211     machine: str
212     weight: int
213     count: int
214
215     def __hash__(self):
216         return hash((self.username, self.machine))
217
218     def __repr__(self):
219         return f'{self.username}@{self.machine}'
220
221
222 @dataclass
223 class BundleDetails:
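    """Everything the RemoteExecutor tracks about one unit of work: the
    pickled code, where its code/result files live, which worker (if
    any) is running it, timing, and backup/cancellation state."""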
224     pickled_code: bytes
225     uuid: str
226     fname: str
227     worker: Optional[RemoteWorkerRecord]
228     username: Optional[str]
229     machine: Optional[str]
230     hostname: str
231     code_file: str
232     result_file: str
233     pid: int
234     start_ts: float
235     end_ts: float
236     slower_than_local_p95: bool
237     slower_than_global_p95: bool
238     src_bundle: BundleDetails
239     is_cancelled: threading.Event
240     was_cancelled: bool
241     backup_bundles: Optional[List[BundleDetails]]
242     failure_count: int
243
244     def __repr__(self):
245         uuid = self.uuid
246         if uuid[-9:-2] == '_backup':
247             uuid = uuid[:-9]
248             suffix = f'{uuid[-6:]}_b{self.uuid[-1:]}'
249         else:
250             suffix = uuid[-6:]
251
252         colorz = [
253             fg('violet red'),
254             fg('red'),
255             fg('orange'),
256             fg('peach orange'),
257             fg('yellow'),
258             fg('marigold yellow'),
259             fg('green yellow'),
260             fg('tea green'),
261             fg('cornflower blue'),
262             fg('turquoise blue'),
263             fg('tropical blue'),
264             fg('lavender purple'),
265             fg('medium purple'),
266         ]
267         c = colorz[int(uuid[-2:], 16) % len(colorz)]
268         fname = self.fname if self.fname is not None else 'nofname'
269         machine = self.machine if self.machine is not None else 'nomachine'
270         return f'{c}{suffix}/{fname}/{machine}{reset()}'
271
272
273 class RemoteExecutorStatus:
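    """Tracks the status of work farmed out to remote workers: which
    bundles are in flight on which machines, per-worker timings, and
    overall progress.  Periodically dumped to stdout."""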
274     def __init__(self, total_worker_count: int) -> None:
275         self.worker_count = total_worker_count
276         self.known_workers: Set[RemoteWorkerRecord] = set()
277         self.start_per_bundle: Dict[str, Optional[float]] = defaultdict(float)
278         self.end_per_bundle: Dict[str, float] = defaultdict(float)
279         self.finished_bundle_timings_per_worker: Dict[
280             RemoteWorkerRecord,
281             List[float]
282         ] = {}
283         self.in_flight_bundles_by_worker: Dict[
284             RemoteWorkerRecord,
285             Set[str]
286         ] = {}
287         self.bundle_details_by_uuid: Dict[str, BundleDetails] = {}
288         self.finished_bundle_timings: List[float] = []
289         self.last_periodic_dump: Optional[float] = None
290         self.total_bundles_submitted = 0
291
292         # Protects reads and modifications of self.  Also used
293         # as a memory fence for modifications to bundles.
294         self.lock = threading.Lock()
295
296     def record_acquire_worker(
297             self,
298             worker: RemoteWorkerRecord,
299             uuid: str
300     ) -> None:
301         with self.lock:
302             self.record_acquire_worker_already_locked(
303                 worker,
304                 uuid
305             )
306
307     def record_acquire_worker_already_locked(
308             self,
309             worker: RemoteWorkerRecord,
310             uuid: str
311     ) -> None:
312         assert self.lock.locked()
313         self.known_workers.add(worker)
314         self.start_per_bundle[uuid] = None
315         x = self.in_flight_bundles_by_worker.get(worker, set())
316         x.add(uuid)
317         self.in_flight_bundles_by_worker[worker] = x
318
319     def record_bundle_details(
320             self,
321             details: BundleDetails) -> None:
322         with self.lock:
323             self.record_bundle_details_already_locked(details)
324
325     def record_bundle_details_already_locked(
326             self,
327             details: BundleDetails) -> None:
328         assert self.lock.locked()
329         self.bundle_details_by_uuid[details.uuid] = details
330
331     def record_release_worker(
332             self,
333             worker: RemoteWorkerRecord,
334             uuid: str,
335             was_cancelled: bool,
336     ) -> None:
337         with self.lock:
338             self.record_release_worker_already_locked(
339                 worker, uuid, was_cancelled
340             )
341
342     def record_release_worker_already_locked(
343             self,
344             worker: RemoteWorkerRecord,
345             uuid: str,
346             was_cancelled: bool,
347     ) -> None:
348         assert self.lock.locked()
349         ts = time.time()
350         self.end_per_bundle[uuid] = ts
351         self.in_flight_bundles_by_worker[worker].remove(uuid)
352         if not was_cancelled:
353             bundle_latency = ts - self.start_per_bundle[uuid]
354             x = self.finished_bundle_timings_per_worker.get(worker, list())
355             x.append(bundle_latency)
356             self.finished_bundle_timings_per_worker[worker] = x
357             self.finished_bundle_timings.append(bundle_latency)
358
359     def record_processing_began(self, uuid: str):
360         with self.lock:
361             self.start_per_bundle[uuid] = time.time()
362
363     def total_in_flight(self) -> int:
364         assert self.lock.locked()
365         total_in_flight = 0
366         for worker in self.known_workers:
367             total_in_flight += len(self.in_flight_bundles_by_worker[worker])
368         return total_in_flight
369
370     def total_idle(self) -> int:
371         assert self.lock.locked()
372         return self.worker_count - self.total_in_flight()
373
374     def __repr__(self):
375         assert self.lock.locked()
376         ts = time.time()
377         total_finished = len(self.finished_bundle_timings)
378         total_in_flight = self.total_in_flight()
379         ret = f'\n\n{underline()}Remote Executor Pool Status{reset()}: '
380         qall = None
381         if len(self.finished_bundle_timings) > 1:
382             qall = numpy.quantile(self.finished_bundle_timings, [0.5, 0.95])
383             ret += (
384                 f'⏱=∀p50:{qall[0]:.1f}s, ∀p95:{qall[1]:.1f}s, '
385                 f'✅={total_finished}/{self.total_bundles_submitted}, '
386                 f'💻n={total_in_flight}/{self.worker_count}\n'
387             )
388         else:
389             ret += (
390                 f' ✅={total_finished}/{self.total_bundles_submitted}, '
391                 f'💻n={total_in_flight}/{self.worker_count}\n'
392             )
393
394         for worker in self.known_workers:
395             ret += f'  {fg("lightning yellow")}{worker.machine}{reset()}: '
396             timings = self.finished_bundle_timings_per_worker.get(worker, [])
397             count = len(timings)
398             qworker = None
399             if count > 1:
400                 qworker = numpy.quantile(timings, [0.5, 0.95])
401                 ret += f' 💻p50: {qworker[0]:.1f}s, 💻p95: {qworker[1]:.1f}s\n'
402             else:
403                 ret += '\n'
404             if count > 0:
405                 ret += f'    ...finished {count} total bundle(s) so far\n'
406             in_flight = len(self.in_flight_bundles_by_worker[worker])
407             if in_flight > 0:
408                 ret += f'    ...{in_flight} bundles currently in flight:\n'
409                 for bundle_uuid in self.in_flight_bundles_by_worker[worker]:
410                     details = self.bundle_details_by_uuid.get(
411                         bundle_uuid,
412                         None
413                     )
414                     pid = str(details.pid) if (details and details.pid != 0) else "TBD"
415                     if self.start_per_bundle[bundle_uuid] is not None:
416                         sec = ts - self.start_per_bundle[bundle_uuid]
417                         ret += f'       (pid={pid}): {details} for {sec:.1f}s so far '
418                     else:
419                         ret += f'       {details} setting up / copying data...'
420                         sec = 0.0
421
422                     if qworker is not None:
423                         if sec > qworker[1]:
424                             ret += f'{bg("red")}>💻p95{reset()} '
425                             if details is not None:
426                                 details.slower_than_local_p95 = True
427                         else:
428                             if details is not None:
429                                 details.slower_than_local_p95 = False
430
431                     if qall is not None:
432                         if sec > qall[1]:
433                             ret += f'{bg("red")}>∀p95{reset()} '
434                             if details is not None:
435                                 details.slower_than_global_p95 = True
436                         elif details is not None:
437                             details.slower_than_global_p95 = False
438                     ret += '\n'
439         return ret
440
441     def periodic_dump(self, total_bundles_submitted: int) -> None:
442         assert self.lock.locked()
443         self.total_bundles_submitted = total_bundles_submitted
444         ts = time.time()
445         if (
446                 self.last_periodic_dump is None
447                 or ts - self.last_periodic_dump > 5.0
448         ):
449             print(self)
450             self.last_periodic_dump = ts
451
452
453 class RemoteWorkerSelectionPolicy(ABC):
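    """Interface for policies that decide which remote worker should be
    given the next bundle."""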
454     def register_worker_pool(self, workers):
455         self.workers = workers
456
457     @abstractmethod
458     def is_worker_available(self) -> bool:
459         pass
460
461     @abstractmethod
462     def acquire_worker(
463             self,
464             machine_to_avoid = None
465     ) -> Optional[RemoteWorkerRecord]:
466         pass
467
468
469 class WeightedRandomRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
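    """Selects a worker at random, biased by each worker's weight, and
    tries to avoid the indicated machine when one is given."""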
470     @overrides
471     def is_worker_available(self) -> bool:
472         for worker in self.workers:
473             if worker.count > 0:
474                 return True
475         return False
476
477     @overrides
478     def acquire_worker(
479             self,
480             machine_to_avoid = None
481     ) -> Optional[RemoteWorkerRecord]:
482         grabbag = []
483         for worker in self.workers:
484             for x in range(0, worker.count):
485                 for y in range(0, worker.weight):
486                     grabbag.append(worker)
487
488         for attempt in range(0, 5):
489             random.shuffle(grabbag)
490             worker = grabbag[0]
491             if worker.machine != machine_to_avoid or attempt > 2:
492                 if worker.count > 0:
493                     worker.count -= 1
494                     logger.debug(f'Selected worker {worker}')
495                     return worker
496         msg = 'Unexpectedly could not find a worker, retrying...'
497         logger.warning(msg)
498         return None
499
500
501 class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
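    """Selects workers in round robin order.  Note that this policy
    currently ignores the machine_to_avoid hint."""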
502     def __init__(self) -> None:
503         self.index = 0
504
505     @overrides
506     def is_worker_available(self) -> bool:
507         for worker in self.workers:
508             if worker.count > 0:
509                 return True
510         return False
511
512     @overrides
513     def acquire_worker(
514             self,
515             machine_to_avoid: Optional[str] = None
516     ) -> Optional[RemoteWorkerRecord]:
517         x = self.index
518         while True:
519             worker = self.workers[x]
520             if worker.count > 0:
521                 worker.count -= 1
522                 x += 1
523                 if x >= len(self.workers):
524                     x = 0
525                 self.index = x
526                 logger.debug(f'Selected worker {worker}')
527                 return worker
528             x += 1
529             if x >= len(self.workers):
530                 x = 0
531             if x == self.index:
532                 msg = 'Unexpectedly could not find a worker, retrying...'
533                 logger.warning(msg)
534                 return None
535
536
537 class RemoteExecutor(BaseExecutor):
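    """An executor that ships cloudpickled bundles to remote workers
    over scp, runs them via ssh, and harvests the results.  A dedicated
    heartbeat thread periodically reports status and, optionally,
    schedules backup copies of bundles that look slow."""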
538     def __init__(self,
539                  workers: List[RemoteWorkerRecord],
540                  policy: RemoteWorkerSelectionPolicy) -> None:
541         super().__init__()
542         self.workers = workers
543         self.policy = policy
544         self.worker_count = 0
545         for worker in self.workers:
546             self.worker_count += worker.count
547         if self.worker_count <= 0:
548             msg = f"We need somewhere to schedule work; count was {self.worker_count}"
549             logger.critical(msg)
550             raise RemoteExecutorException(msg)
551         self.policy.register_worker_pool(self.workers)
552         self.cv = threading.Condition()
553         logger.debug(f'Creating {self.worker_count} local threads, one per remote worker.')
554         self._helper_executor = fut.ThreadPoolExecutor(
555             thread_name_prefix="remote_executor_helper",
556             max_workers=self.worker_count,
557         )
558         self.status = RemoteExecutorStatus(self.worker_count)
559         self.total_bundles_submitted = 0
560         self.backup_lock = threading.Lock()
561         self.last_backup = None
562         (self.heartbeat_thread, self.heartbeat_stop_event) = self.run_periodic_heartbeat()
563
564     @background_thread
565     def run_periodic_heartbeat(self, stop_event) -> None:
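        """Background thread (started at construction time) that runs
        the heartbeat every five seconds until stop_event is set."""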
566         while not stop_event.is_set():
567             time.sleep(5.0)
568             logger.debug('Running periodic heartbeat code...')
569             self.heartbeat()
570         logger.debug('Periodic heartbeat thread shutting down.')
571
572     def heartbeat(self) -> None:
573         with self.status.lock:
574             # Dump regular progress report
575             self.status.periodic_dump(self.total_bundles_submitted)
576
577             # Look for bundles to reschedule via executor.submit
578             if config.config['executors_schedule_remote_backups']:
579                 self.maybe_schedule_backup_bundles()
580
581     def maybe_schedule_backup_bundles(self):
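        """Consider scheduling a backup copy of the most deserving
        in-flight bundle: prefer bundles on slower machines, bundles
        that have been running a long time, and bundles that don't
        already have backups."""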
582         assert self.status.lock.locked()
583         num_done = len(self.status.finished_bundle_timings)
584         num_idle_workers = self.worker_count - self.task_count
585         now = time.time()
586         if (
587                 num_done > 2
588                 and num_idle_workers > 1
589                 and (self.last_backup is None or (now - self.last_backup > 6.0))
590                 and self.backup_lock.acquire(blocking=False)
591         ):
592             try:
593                 assert self.backup_lock.locked()
594
595                 bundle_to_backup = None
596                 best_score = None
597                 for worker, bundle_uuids in self.status.in_flight_bundles_by_worker.items():
598
599                     # Prefer to schedule backups of bundles running on
600                     # slower machines.
601                     base_score = 0
602                     for record in self.workers:
603                         if worker.machine == record.machine:
604                             base_score = float(record.weight)
605                             base_score = 1.0 / base_score
606                             base_score *= 200.0
607                             base_score = int(base_score)
608                             break
609
610                     for uuid in bundle_uuids:
611                         bundle = self.status.bundle_details_by_uuid.get(uuid, None)
612                         if (
613                                 bundle is not None
614                                 and bundle.src_bundle is None
615                                 and bundle.backup_bundles is not None
616                         ):
617                             score = base_score
618
619                             # Schedule backups of bundles running
620                             # longer; especially those that are
621                             # unexpectedly slow.
622                             start_ts = self.status.start_per_bundle[uuid]
623                             if start_ts is not None:
624                                 runtime = now - start_ts
625                                 score += runtime
626                                 logger.debug(f'score[{bundle}] => {score}  # latency boost')
627
628                                 if bundle.slower_than_local_p95:
629                                     score += runtime / 2
630                                     logger.debug(f'score[{bundle}] => {score}  # >worker p95')
631
632                                 if bundle.slower_than_global_p95:
633                                     score += runtime / 4
634                                     logger.debug(f'score[{bundle}] => {score}  # >global p95')
635
636                             # Prefer backups of bundles that don't
637                             # have backups already.
638                             backup_count = len(bundle.backup_bundles)
639                             if backup_count == 0:
640                                 score *= 2
641                             elif backup_count == 1:
642                                 score /= 2
643                             elif backup_count == 2:
644                                 score /= 8
645                             else:
646                                 score = 0
647                             logger.debug(
648                                 f'score[{bundle}] => {score}  # {backup_count} dup backup factor'
649                             )
650
651                             if (
652                                     score != 0
653                                     and (best_score is None or score > best_score)
654                             ):
655                                 bundle_to_backup = bundle
656                                 assert bundle is not None
657                                 assert bundle.backup_bundles is not None
658                                 assert bundle.src_bundle is None
659                                 best_score = score
660
661                 # Note: this is all still happening on the heartbeat
662                 # runner thread.  That's ok because
663                 # schedule_backup_for_bundle uses the executor to
664                 # submit the bundle again which will cause it to be
665                 # picked up by a worker thread and allow this thread
666                 # to return to run future heartbeats.
667                 if bundle_to_backup is not None:
668                     self.last_backup = now
669                     logger.info(
670                         f'=====> SCHEDULING BACKUP {bundle_to_backup} (score={best_score:.1f}) <====='
671                     )
672                     self.schedule_backup_for_bundle(bundle_to_backup)
673             finally:
674                 self.backup_lock.release()
675
676     def is_worker_available(self) -> bool:
677         return self.policy.is_worker_available()
678
679     def acquire_worker(
680             self,
681             machine_to_avoid: Optional[str] = None
682     ) -> Optional[RemoteWorkerRecord]:
683         return self.policy.acquire_worker(machine_to_avoid)
684
685     def find_available_worker_or_block(
686             self,
687             machine_to_avoid: Optional[str] = None
688     ) -> RemoteWorkerRecord:
689         with self.cv:
690             while not self.is_worker_available():
691                 self.cv.wait()
692             worker = self.acquire_worker(machine_to_avoid)
693             if worker is not None:
694                 return worker
695         msg = "We should never reach this point in the code"
696         logger.critical(msg)
697         raise Exception(msg)
698
699     def release_worker(
700             self,
701             bundle: BundleDetails,
702             *,
703             was_cancelled=True
704     ) -> None:
705         worker = bundle.worker
706         assert worker is not None
707         logger.debug(f'Released worker {worker}')
708         self.status.record_release_worker(
709             worker,
710             bundle.uuid,
711             was_cancelled,
712         )
713         with self.cv:
714             worker.count += 1
715             self.cv.notify()
716         self.adjust_task_count(-1)
717
718     def check_if_cancelled(self, bundle: BundleDetails) -> bool:
719         with self.status.lock:
720             if bundle.is_cancelled.wait(timeout=0.0):
721                 logger.debug(f'Bundle {bundle.uuid} is cancelled, bail out.')
722                 bundle.was_cancelled = True
723                 return True
724         return False
725
726     def launch(self, bundle: BundleDetails, override_avoid_machine=None) -> Any:
727         """Find a worker for bundle or block until one is available."""
728         self.adjust_task_count(+1)
729         uuid = bundle.uuid
730         hostname = bundle.hostname
731         avoid_machine = override_avoid_machine
732         is_original = bundle.src_bundle is None
733
734         # Try not to schedule a backup on the same host as the original.
735         if avoid_machine is None and bundle.src_bundle is not None:
736             avoid_machine = bundle.src_bundle.machine
737         worker = None
738         while worker is None:
739             worker = self.find_available_worker_or_block(avoid_machine)
740         assert worker
741
742         # Ok, found a worker.
743         bundle.worker = worker
744         machine = bundle.machine = worker.machine
745         username = bundle.username = worker.username
746         self.status.record_acquire_worker(worker, uuid)
747         logger.debug(f'{bundle}: Running bundle on {worker}...')
748
749         # Before we do any work, make sure the bundle is still viable.
750         # It may have been some time between when it was submitted and
751         # now due to lack of worker availability and someone else may
752         # have already finished it.
753         if self.check_if_cancelled(bundle):
754             try:
755                 return self.process_work_result(bundle)
756             except Exception as e:
757                 logger.warning(
758                     f'{bundle}: bundle says it\'s cancelled upfront but no results?!'
759                 )
760                 self.release_worker(bundle)
761                 if is_original:
762                     # Weird.  We are the original owner of this
763                     # bundle.  For it to have been cancelled, a backup
764                     # must have already started and completed before
765                     # we even got started.  Moreover, the backup says
766                     # it is done but we can't find the results it
767                     # should have copied over.  Reschedule the whole
768                     # thing.
769                     logger.exception(e)
770                     logger.error(
771                         f'{bundle}: We are the original owner thread and yet there are ' +
772                         'no results for this bundle.  This is unexpected and bad.'
773                     )
774                     return self.emergency_retry_nasty_bundle(bundle)
775                 else:
776                     # Expected(?).  We're a backup and our bundle is
777                     # cancelled before we even got started.  Something
778                     # went bad in process_work_result (I actually don't
779                     # see what?) but probably not worth worrying
780                     # about.  Let the original thread worry about
781                     # either finding the results or complaining about
782                     # it.
783                     return None
784
785         # Send input code / data to worker machine if it's not local.
786         if hostname not in machine:
787             try:
788                 cmd = f'{SCP} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
789                 start_ts = time.time()
790                 logger.info(f"{bundle}: Copying work to {worker} via {cmd}.")
791                 run_silently(cmd)
792                 xfer_latency = time.time() - start_ts
793                 logger.info(f"{bundle}: Copying to {worker} took {xfer_latency:.1f}s.")
794             except Exception as e:
795                 self.release_worker(bundle)
796                 if is_original:
797                     # Weird.  We tried to copy the code to the worker and it failed...
798                     # And we're the original bundle.  We have to retry.
799                     logger.exception(e)
800                     logger.error(
801                         f"{bundle}: Failed to send instructions to the worker machine?! " +
802                         "This is not expected; we\'re the original bundle so this shouldn\'t " +
803                         "be a race condition.  Attempting an emergency retry..."
804                     )
805                     return self.emergency_retry_nasty_bundle(bundle)
806                 else:
807                     # This is actually expected; we're a backup.
808                     # There's a race condition where someone else
809                     # already finished the work and removed the source
810                     # code file before we could copy it.  No biggie.
811                     msg = f'{bundle}: Failed to send instructions to the worker machine... '
812                     msg += 'We\'re a backup and this may be caused by the original (or some '
813                     msg += 'other backup) already finishing this work.  Ignoring this.'
814                     logger.warning(msg)
815                     return None
816
817         # Kick off the work.  Note that if this fails we let
818         # wait_for_process deal with it.
819         self.status.record_processing_began(uuid)
820         cmd = (f'{SSH} {bundle.username}@{bundle.machine} '
821                f'"source py38-venv/bin/activate &&'
822                f' /home/scott/lib/python_modules/remote_worker.py'
823                f' --code_file {bundle.code_file} --result_file {bundle.result_file}"')
824         logger.debug(f'{bundle}: Executing {cmd} in the background to kick off work...')
825         p = cmd_in_background(cmd, silent=True)
826         bundle.pid = p.pid
827         logger.debug(f'{bundle}: Local ssh process pid={p.pid}; remote worker is {machine}.')
828         return self.wait_for_process(p, bundle, 0)
829
830     def wait_for_process(self, p: subprocess.Popen, bundle: BundleDetails, depth: int) -> Any:
831         machine = bundle.machine
832         pid = p.pid
833         if depth > 3:
834             logger.error(
835                 f"I've gotten repeated errors waiting on this bundle; giving up on pid={pid}."
836             )
837             p.terminate()
838             self.release_worker(bundle)
839             return self.emergency_retry_nasty_bundle(bundle)
840
841         # Spin until either the ssh job we scheduled finishes the
842         # bundle or some backup worker signals that they finished it
843         # before we could.
844         while True:
845             try:
846                 p.wait(timeout=0.25)
847             except subprocess.TimeoutExpired:
848                 if self.check_if_cancelled(bundle):
849                     logger.info(
850                         f'{bundle}: another worker finished bundle, checking it out...'
851                     )
852                     break
853             else:
854                 logger.info(
855                     f"{bundle}: pid {pid} ({machine}) is finished!"
856                 )
857                 p = None
858                 break
859
860         # If we get here we believe the bundle is done; either the ssh
861         # subprocess finished (hopefully successfully) or we noticed
862         # that some other worker seems to have completed the bundle
863         # and we're bailing out.
864         try:
865             ret = self.process_work_result(bundle)
866             if ret is not None and p is not None:
867                 p.terminate()
868             return ret
869
870         # Something went wrong; e.g. we could not copy the results
871         # back, clean up after ourselves on the remote machine, or
872         # unpickle the results we got from the remote machine.  If we
873         # still have an active ssh subprocess, keep waiting on it.
874         # Otherwise, time for an emergency reschedule.
875         except Exception as e:
876             logger.exception(e)
877             logger.error(f'{bundle}: Something unexpected just happened...')
878             if p is not None:
879                 msg = f"{bundle}: Failed to wrap up \"done\" bundle, re-waiting on active ssh."
880                 logger.warning(msg)
881                 return self.wait_for_process(p, bundle, depth + 1)
882             else:
883                 self.release_worker(bundle)
884                 return self.emergency_retry_nasty_bundle(bundle)
885
886     def process_work_result(self, bundle: BundleDetails) -> Any:
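        """Copy the result file back from the worker (if remote), unpickle
        it if we are the original bundle, notify our counterpart bundles
        (original or backups), and release the worker."""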
887         with self.status.lock:
888             is_original = bundle.src_bundle is None
889             was_cancelled = bundle.was_cancelled
890             username = bundle.username
891             machine = bundle.machine
892             result_file = bundle.result_file
893             code_file = bundle.code_file
894
895             # Whether original or backup, if we finished first we must
896             # fetch the results if the computation happened on a
897             # remote machine.
898             bundle.end_ts = time.time()
899             if not was_cancelled:
900                 assert bundle.machine is not None
901                 if bundle.hostname not in bundle.machine:
902                     cmd = f'{SCP} {username}@{machine}:{result_file} {result_file} 2>/dev/null'
903                     logger.info(
904                         f"{bundle}: Fetching results back from {username}@{machine} via {cmd}"
905                     )
906
907                     # If either of these throw they are handled in
908                     # wait_for_process.
909                     attempts = 0
910                     while True:
911                         try:
912                             run_silently(cmd)
913                         except Exception as e:
914                             attempts += 1
915                             if attempts >= 3:
916                                 raise e
917                         else:
918                             break
919
920                     run_silently(f'{SSH} {username}@{machine}'
921                                  f' "/bin/rm -f {code_file} {result_file}"')
922                     logger.debug(f'Fetching results back took {time.time() - bundle.end_ts:.1f}s.')
923                 dur = bundle.end_ts - bundle.start_ts
924                 self.histogram.add_item(dur)
925
926         # Only the original worker should unpickle the file contents
927         # though since it's the only one whose result matters.  The
928         # original is also the only job that may delete result_file
929         # from disk.  Note that the original may have been cancelled
930         # if one of the backups finished first; it still must read the
931         # result from disk.
932         if is_original:
933             logger.debug(f"{bundle}: Unpickling {result_file}.")
934             try:
935                 with open(result_file, 'rb') as rb:
936                     serialized = rb.read()
937                 result = cloudpickle.loads(serialized)
938             except Exception as e:
939                 logger.exception(e)
940                 msg = f'Failed to load {result_file}... this is bad news.'
941                 logger.critical(msg)
942                 self.release_worker(bundle)
943
944                 # Re-raise the exception; the code in wait_for_process may
945                 # decide to emergency_retry_nasty_bundle here.
946                 raise e
947             logger.debug(
948                 f'Removing local (master) {code_file} and {result_file}.'
949             )
950             os.remove(f'{result_file}')
951             os.remove(f'{code_file}')
952
953             # Notify any backups that the original is done so they
954             # should stop ASAP.  Do this whether or not we
955             # finished first since there could be more than one
956             # backup.
957             if bundle.backup_bundles is not None:
958                 for backup in bundle.backup_bundles:
959                     logger.debug(
960                         f'{bundle}: Notifying backup {backup.uuid} that it\'s cancelled'
961                     )
962                     backup.is_cancelled.set()
963
964         # This is a backup job and, by now, we have already fetched
965         # the bundle results.
966         else:
967             # Backup results don't matter, they just need to leave the
968             # result file in the right place for their originals to
969             # read/unpickle later.
970             result = None
971
972             # Tell the original to stop if we finished first.
973             if not was_cancelled:
974                 logger.debug(
975                     f'{bundle}: Notifying original {bundle.src_bundle.uuid} we beat them to it.'
976                 )
977                 bundle.src_bundle.is_cancelled.set()
978         self.release_worker(bundle, was_cancelled=was_cancelled)
979         return result
980
981     def create_original_bundle(self, pickle, fname: str):
982         from string_utils import generate_uuid
983         uuid = generate_uuid(omit_dashes=True)
984         code_file = f'/tmp/{uuid}.code.bin'
985         result_file = f'/tmp/{uuid}.result.bin'
986
987         logger.debug(f'Writing pickled code to {code_file}')
988         with open(f'{code_file}', 'wb') as wb:
989             wb.write(pickle)
990
991         bundle = BundleDetails(
992             pickled_code = pickle,
993             uuid = uuid,
994             fname = fname,
995             worker = None,
996             username = None,
997             machine = None,
998             hostname = platform.node(),
999             code_file = code_file,
1000             result_file = result_file,
1001             pid = 0,
1002             start_ts = time.time(),
1003             end_ts = 0.0,
1004             slower_than_local_p95 = False,
1005             slower_than_global_p95 = False,
1006             src_bundle = None,
1007             is_cancelled = threading.Event(),
1008             was_cancelled = False,
1009             backup_bundles = [],
1010             failure_count = 0,
1011         )
1012         self.status.record_bundle_details(bundle)
1013         logger.debug(f'{bundle}: Created an original bundle')
1014         return bundle
1015
1016     def create_backup_bundle(self, src_bundle: BundleDetails):
1017         assert src_bundle.backup_bundles is not None
1018         n = len(src_bundle.backup_bundles)
1019         uuid = src_bundle.uuid + f'_backup#{n}'
1020
1021         backup_bundle = BundleDetails(
1022             pickled_code = src_bundle.pickled_code,
1023             uuid = uuid,
1024             fname = src_bundle.fname,
1025             worker = None,
1026             username = None,
1027             machine = None,
1028             hostname = src_bundle.hostname,
1029             code_file = src_bundle.code_file,
1030             result_file = src_bundle.result_file,
1031             pid = 0,
1032             start_ts = time.time(),
1033             end_ts = 0.0,
1034             slower_than_local_p95 = False,
1035             slower_than_global_p95 = False,
1036             src_bundle = src_bundle,
1037             is_cancelled = threading.Event(),
1038             was_cancelled = False,
1039             backup_bundles = None,    # backup backups not allowed
1040             failure_count = 0,
1041         )
1042         src_bundle.backup_bundles.append(backup_bundle)
1043         self.status.record_bundle_details_already_locked(backup_bundle)
1044         logger.debug(f'{backup_bundle}: Created a backup bundle')
1045         return backup_bundle
1046
1047     def schedule_backup_for_bundle(self,
1048                                    src_bundle: BundleDetails):
1049         assert self.status.lock.locked()
1050         assert src_bundle is not None
1051         backup_bundle = self.create_backup_bundle(src_bundle)
1052         logger.debug(
1053             f'{backup_bundle.uuid}/{backup_bundle.fname}: Scheduling backup for execution...'
1054         )
1055         self._helper_executor.submit(self.launch, backup_bundle)
1056
1057         # Results from backups don't matter; if they finish first
1058         # they will move the result_file to this machine and let
1059         # the original pick them up and unpickle them.
1060
1061     def emergency_retry_nasty_bundle(self, bundle: BundleDetails) -> Any:
1062         is_original = bundle.src_bundle is None
1063         bundle.worker = None
1064         avoid_last_machine = bundle.machine
1065         bundle.machine = None
1066         bundle.username = None
1067         bundle.failure_count += 1
1068         if is_original:
1069             retry_limit = 3
1070         else:
1071             retry_limit = 2
1072
1073         if bundle.failure_count > retry_limit:
1074             logger.error(
1075                 f'{bundle}: Tried this bundle too many times already ({retry_limit}x); giving up.'
1076             )
1077             if is_original:
1078                 raise RemoteExecutorException(
1079                     f'{bundle}: This bundle can\'t be completed despite several backups and retries'
1080                 )
1081             else:
1082                 logger.error(f'{bundle}: At least it\'s only a backup; better luck with the others.')
1083             return None
1084         else:
1085             msg = f'>>> Emergency rescheduling {bundle} because of unexpected errors (wtf?!) <<<'
1086             logger.warning(msg)
1087             warnings.warn(msg)
1088             return self.launch(bundle, avoid_last_machine)
1089
1090     @overrides
1091     def submit(self,
1092                function: Callable,
1093                *args,
1094                **kwargs) -> fut.Future:
1095         pickle = make_cloud_pickle(function, *args, **kwargs)
1096         bundle = self.create_original_bundle(pickle, function.__name__)
1097         self.total_bundles_submitted += 1
1098         return self._helper_executor.submit(self.launch, bundle)
1099
1100     @overrides
1101     def shutdown(self, wait=True) -> None:
1102         logger.debug(f'Shutting down RemoteExecutor {self.title}')
1103         self.heartbeat_stop_event.set()
1104         self.heartbeat_thread.join()
1105         self._helper_executor.shutdown(wait)
1106         print(self.histogram)
1107
1108
1109 @singleton
1110 class DefaultExecutors(object):
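    """A (singleton) lazy factory for the default thread, process, and
    remote executors.  The remote pool is composed of whichever known
    helper machines answer a ping."""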
1111     def __init__(self):
1112         self.thread_executor: Optional[ThreadExecutor] = None
1113         self.process_executor: Optional[ProcessExecutor] = None
1114         self.remote_executor: Optional[RemoteExecutor] = None
1115
1116     def ping(self, host) -> bool:
1117         logger.debug(f'RUN> ping -c 1 {host}')
1118         try:
1119             x = cmd_with_timeout(
1120                 f'ping -c 1 {host} >/dev/null 2>/dev/null',
1121                 timeout_seconds=1.0
1122             )
1123             return x == 0
1124         except Exception:
1125             return False
1126
1127     def thread_pool(self) -> ThreadExecutor:
1128         if self.thread_executor is None:
1129             self.thread_executor = ThreadExecutor()
1130         return self.thread_executor
1131
1132     def process_pool(self) -> ProcessExecutor:
1133         if self.process_executor is None:
1134             self.process_executor = ProcessExecutor()
1135         return self.process_executor
1136
1137     def remote_pool(self) -> RemoteExecutor:
1138         if self.remote_executor is None:
1139             logger.info('Looking for some helper machines...')
1140             pool: List[RemoteWorkerRecord] = []
1141             if self.ping('cheetah.house'):
1142                 logger.info('Found cheetah.house')
1143                 pool.append(
1144                     RemoteWorkerRecord(
1145                         username = 'scott',
1146                         machine = 'cheetah.house',
1147                         weight = 25,
1148                         count = 6,
1149                     ),
1150                 )
1151             if self.ping('meerkat.cabin'):
1152                 logger.info('Found meerkat.cabin')
1153                 pool.append(
1154                     RemoteWorkerRecord(
1155                         username = 'scott',
1156                         machine = 'meerkat.cabin',
1157                         weight = 12,
1158                         count = 2,
1159                     ),
1160                 )
1161             if self.ping('wannabe.house'):
1162                 logger.info('Found wannabe.house')
1163                 pool.append(
1164                     RemoteWorkerRecord(
1165                         username = 'scott',
1166                         machine = 'wannabe.house',
1167                         weight = 30,
1168                         count = 10,
1169                     ),
1170                 )
1171             if self.ping('puma.cabin'):
1172                 logger.info('Found puma.cabin')
1173                 pool.append(
1174                     RemoteWorkerRecord(
1175                         username = 'scott',
1176                         machine = 'puma.cabin',
1177                         weight = 25,
1178                         count = 6,
1179                     ),
1180                 )
1181             if self.ping('backup.house'):
1182                 logger.info('Found backup.house')
1183                 pool.append(
1184                     RemoteWorkerRecord(
1185                         username = 'scott',
1186                         machine = 'backup.house',
1187                         weight = 7,
1188                         count = 2,
1189                     ),
1190                 )
1191
1192             # The controller machine has a lot to do; go easy on it.
1193             for record in pool:
1194                 if record.machine == platform.node() and record.count > 1:
1195                     logger.info(f'Reducing workload for {record.machine}.')
1196                     record.count = 1
1197
1198             policy = WeightedRandomRemoteWorkerSelectionPolicy()
1199             policy.register_worker_pool(pool)
1200             self.remote_executor = RemoteExecutor(pool, policy)
1201         return self.remote_executor
1202
1203     def shutdown(self) -> None:
1204         if self.thread_executor is not None:
1205             self.thread_executor.shutdown()
1206             self.thread_executor = None
1207         if self.process_executor is not None:
1208             self.process_executor.shutdown()
1209             self.process_executor = None
1210         if self.remote_executor is not None:
1211             self.remote_executor.shutdown()
1212             self.remote_executor = None
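
# A minimal usage sketch (illustrative only; it assumes config.parse()
# -- or whatever normally populates config.config in this project -- has
# already run, and `expensive_work` is a hypothetical caller-supplied
# function):
#
#     from executors import DefaultExecutors
#
#     executors = DefaultExecutors()
#     pool = executors.thread_pool()       # or process_pool() / remote_pool()
#     future = pool.submit(expensive_work, 42)
#     print(future.result())
#     executors.shutdown()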