Add coding comments to files that contain UTF-8 characters.
[python_utils.git] / executors.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3
4 from __future__ import annotations
5
6 import concurrent.futures as fut
7 import logging
8 import os
9 import platform
10 import random
11 import subprocess
12 import threading
13 import time
14 import warnings
15 from abc import ABC, abstractmethod
16 from collections import defaultdict
17 from dataclasses import dataclass
18 from typing import Any, Callable, Dict, List, Optional, Set
19
20 import cloudpickle  # type: ignore
21 import numpy
22 from overrides import overrides
23
24 import argparse_utils
25 import config
26 import histogram as hist
27 from ansi import bg, fg, reset, underline
28 from decorator_utils import singleton
29 from exec_utils import cmd_in_background, cmd_with_timeout, run_silently
30 from thread_utils import background_thread
31
logger = logging.getLogger(__name__)

# Command line flags for this module.  config.add_commandline_args returns
# the object the flags below are registered on.
parser = config.add_commandline_args(
    f"Executors ({__file__})",
    "Args related to processing executors.",
)
parser.add_argument(
    '--executors_threadpool_size',
    default=None,
    type=int,
    metavar='#THREADS',
    help='Number of threads in the default threadpool, leave unset for default',
)
parser.add_argument(
    '--executors_processpool_size',
    default=None,
    type=int,
    metavar='#PROCESSES',
    help='Number of processes in the default processpool, leave unset for default',
)
parser.add_argument(
    '--executors_schedule_remote_backups',
    action=argparse_utils.ActionNoYes,
    default=True,
    help='Should we schedule duplicative backup work if a remote bundle is slow',
)
parser.add_argument(
    '--executors_max_bundle_failures',
    default=3,
    type=int,
    metavar='#FAILURES',
    help='Maximum number of failures before giving up on a bundle',
)

# Command prefixes used to reach remote workers (ssh with X11 forwarding off).
SSH = '/usr/bin/ssh -oForwardX11=no'
SCP = '/usr/bin/scp -C'
67
68
def make_cloud_pickle(fun, *args, **kwargs):
    """Serialize ``fun`` and its call arguments into one cloudpickle blob.

    The blob decodes (via ``cloudpickle.loads``) back into a
    ``(fun, args, kwargs)`` tuple; ProcessExecutor.run_cloud_pickle unpacks
    and invokes it.

    Args:
        fun: the callable to run.  It need not have a ``__name__``
            (e.g. a ``functools.partial``); the log line falls back to repr.
        *args, **kwargs: arguments for ``fun``.

    Returns:
        The serialized bytes.
    """
    logger.debug('Making cloudpickled bundle at %s', getattr(fun, '__name__', repr(fun)))
    return cloudpickle.dumps((fun, args, kwargs))
72
73
class BaseExecutor(ABC):
    """Abstract base for the executors in this module.

    Subclasses implement ``submit`` and ``shutdown``.  The base class keeps a
    count of outstanding tasks (``task_count``) and a latency histogram that
    the subclasses in this module feed as work completes.
    """

    def __init__(self, *, title=''):
        self.title = title
        # 50 buckets over 0..500; ThreadExecutor/ProcessExecutor add
        # time.time() deltas (seconds) here when a task finishes.
        self.histogram = hist.SimpleHistogram(
            hist.SimpleHistogram.n_evenly_spaced_buckets(int(0), int(500), 50)
        )
        self.task_count = 0

    @abstractmethod
    def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
        """Schedule ``function(*args, **kwargs)`` and return a Future for it."""

    @abstractmethod
    def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
        """Stop the executor.

        In this module's subclasses, ``wait`` is forwarded to the underlying
        pool's shutdown and ``quiet`` suppresses printing the histogram.
        """

    def shutdown_if_idle(self, *, quiet: bool = False) -> bool:
        """Shutdown the executor and return True if the executor is idle
        (i.e. there are no pending or active tasks).  Return False
        otherwise.  Note: this should only be called by the launcher
        process.

        """
        if self.task_count == 0:
            self.shutdown(wait=True, quiet=quiet)
            return True
        return False

    def adjust_task_count(self, delta: int) -> None:
        """Add ``delta`` (e.g. +1 on submit, -1 on completion) to the task count.

        Note: do not call this method from a worker, it should only be
        called by the launcher process / thread / machine.
        """
        # NOTE(review): this read-modify-write is not synchronized.  The
        # subclasses call it from Future done-callbacks, which
        # concurrent.futures may run on a pool-internal thread.
        self.task_count += delta
        logger.debug('Adjusted task count by %s to %s', delta, self.task_count)

    def get_task_count(self) -> int:
        """Return the number of submitted tasks that have not finished yet."""
        return self.task_count
118
119
class ThreadExecutor(BaseExecutor):
    """Executor that runs callables on a local ``ThreadPoolExecutor``."""

    def __init__(self, max_workers: Optional[int] = None):
        """Create the pool.

        Args:
            max_workers: pool size; if None, falls back to the
                --executors_threadpool_size flag (and then to the
                ThreadPoolExecutor default when that is unset).
        """
        super().__init__()
        workers = None
        if max_workers is not None:
            workers = max_workers
        elif 'executors_threadpool_size' in config.config:
            workers = config.config['executors_threadpool_size']
        logger.debug('Creating threadpool executor with %s workers', workers)
        self._thread_pool_executor = fut.ThreadPoolExecutor(
            max_workers=workers, thread_name_prefix="thread_executor_helper"
        )
        self.already_shutdown = False

    # This is run on a different thread; do not adjust task count here.
    def run_local_bundle(self, fun, *args, **kwargs):
        """Invoke ``fun(*args, **kwargs)`` on a pool thread and return its result."""
        # getattr: callables such as functools.partial have no __name__.
        logger.debug('Running local bundle at %s', getattr(fun, '__name__', repr(fun)))
        return fun(*args, **kwargs)

    @overrides
    def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
        """Run ``function(*args, **kwargs)`` on the pool.

        Returns:
            The Future; on completion its latency is added to the histogram
            and the task count is decremented.

        Raises:
            Exception: if called after shutdown().
        """
        if self.already_shutdown:
            raise Exception('Submitted work after shutdown.')
        start = time.time()
        result = self._thread_pool_executor.submit(
            self.run_local_bundle, function, *args, **kwargs
        )
        # Count the task only once the pool accepted it: if submit() raised,
        # no done-callback would ever undo the increment and
        # shutdown_if_idle() would never see an idle executor.
        self.adjust_task_count(+1)
        result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
        result.add_done_callback(lambda _: self.adjust_task_count(-1))
        return result

    @overrides
    def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
        """Shut the pool down once; print the latency histogram unless quiet."""
        if not self.already_shutdown:
            logger.debug('Shutting down threadpool executor %s', self.title)
            self._thread_pool_executor.shutdown(wait)
            if not quiet:
                print(self.histogram.__repr__(label_formatter='%ds'))
            self.already_shutdown = True
163
164
class ProcessExecutor(BaseExecutor):
    """Executor that runs callables in a local ``ProcessPoolExecutor``.

    Work is serialized with make_cloud_pickle and decoded inside the worker
    process by run_cloud_pickle.
    """

    def __init__(self, max_workers: Optional[int] = None):
        """Create the pool.

        Args:
            max_workers: pool size; if None, falls back to the
                --executors_processpool_size flag (and then to the
                ProcessPoolExecutor default when that is unset).
        """
        super().__init__()
        workers = None
        if max_workers is not None:
            workers = max_workers
        elif 'executors_processpool_size' in config.config:
            workers = config.config['executors_processpool_size']
        logger.debug('Creating processpool executor with %s workers.', workers)
        self._process_executor = fut.ProcessPoolExecutor(
            max_workers=workers,
        )
        self.already_shutdown = False

    # This is run in another process; do not adjust task count here.
    def run_cloud_pickle(self, pickle):
        """Decode a make_cloud_pickle blob, invoke it, and return its result."""
        fun, args, kwargs = cloudpickle.loads(pickle)
        # getattr: callables such as functools.partial have no __name__.
        logger.debug('Running pickled bundle at %s', getattr(fun, '__name__', repr(fun)))
        return fun(*args, **kwargs)

    @overrides
    def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
        """Run ``function(*args, **kwargs)`` in a worker process.

        Returns:
            The Future; on completion its latency (including pickling time)
            is added to the histogram and the task count is decremented.

        Raises:
            Exception: if called after shutdown().  Errors from pickling
            or from the pool's submit() propagate unchanged.
        """
        if self.already_shutdown:
            raise Exception('Submitted work after shutdown.')
        start = time.time()
        blob = make_cloud_pickle(function, *args, **kwargs)
        result = self._process_executor.submit(self.run_cloud_pickle, blob)
        # Count the task only after pickling and submission succeeded:
        # otherwise no done-callback would ever undo the increment.
        self.adjust_task_count(+1)
        result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
        result.add_done_callback(lambda _: self.adjust_task_count(-1))
        return result

    @overrides
    def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
        """Shut the pool down once; print the latency histogram unless quiet."""
        if not self.already_shutdown:
            logger.debug('Shutting down processpool executor %s', self.title)
            self._process_executor.shutdown(wait)
            if not quiet:
                print(self.histogram.__repr__(label_formatter='%ds'))
            self.already_shutdown = True

    def __getstate__(self):
        """Pickle support.

        submit() hands the bound method ``self.run_cloud_pickle`` to the
        pool, so this instance gets pickled; leave the pool object itself
        out of the copied state.
        """
        state = self.__dict__.copy()
        state['_process_executor'] = None
        return state
211
212
class RemoteExecutorException(Exception):
    """Raised when a bundle cannot be executed despite several retries.

    RemoteExecutor's constructor also raises it when the supplied workers
    offer no capacity to schedule work on.
    """
217
218
@dataclass
class RemoteWorkerRecord:
    """One remote login@machine that bundles can be shipped to."""

    username: str  # login name used in user@machine for remote commands
    machine: str  # remote host name
    weight: int  # relative preference used by the selection policies
    count: int  # free execution slots; policies decrement it on acquire

    def __hash__(self):
        # Hash only on identity (username, machine); count is mutated while
        # the record sits in sets/dict keys, so it must not affect the hash.
        key = (self.username, self.machine)
        return hash(key)

    def __repr__(self):
        return '{}@{}'.format(self.username, self.machine)
231
232
@dataclass
class BundleDetails:
    """State for one unit of remote work (an original bundle or a backup)."""

    pickled_code: bytes
    uuid: str
    fname: str
    # worker / username / machine stay None until a worker is acquired.
    worker: Optional[RemoteWorkerRecord]
    username: Optional[str]
    machine: Optional[str]
    hostname: str
    code_file: str
    result_file: str
    pid: int  # 0 means "not known yet" in the status report
    start_ts: float
    end_ts: float
    # Refreshed by RemoteExecutorStatus.__repr__; read by backup scoring.
    slower_than_local_p95: bool
    slower_than_global_p95: bool
    # None for an original bundle; set on backups.
    src_bundle: Optional[BundleDetails]
    # Once is_cancelled is set, check_if_cancelled marks was_cancelled.
    is_cancelled: threading.Event
    was_cancelled: bool
    backup_bundles: Optional[List[BundleDetails]]
    # NOTE(review): presumably compared to --executors_max_bundle_failures.
    failure_count: int

    def __repr__(self):
        # A backup's uuid ends in a 9-character suffix starting with
        # '_backup'; strip it so the backup is colored like its source, and
        # tag the short id with the suffix's last character.
        base = self.uuid
        if base[-9:-2] != '_backup':
            tag = base[-6:]
        else:
            base = base[:-9]
            tag = f'{base[-6:]}_b{self.uuid[-1:]}'

        palette = [
            fg(color_name)
            for color_name in (
                'violet red',
                'red',
                'orange',
                'peach orange',
                'yellow',
                'marigold yellow',
                'green yellow',
                'tea green',
                'cornflower blue',
                'turquoise blue',
                'tropical blue',
                'lavender purple',
                'medium purple',
            )
        ]
        # Last two hex digits of the (source) uuid pick a stable color.
        color = palette[int(base[-2:], 16) % len(palette)]
        fname = 'nofname' if self.fname is None else self.fname
        machine = 'nomachine' if self.machine is None else self.machine
        return f'{color}{tag}/{fname}/{machine}{reset()}'
282
283
class RemoteExecutorStatus:
    """Bookkeeping for a RemoteExecutor's workers and in-flight bundles.

    Tracks which bundles are in flight on which worker, per-bundle start and
    end timestamps, and finished-bundle latencies (overall and per worker),
    and renders them as a status report.  ``self.lock`` guards all of it;
    the ``*_already_locked`` methods, ``total_in_flight``, ``total_idle``,
    ``__repr__`` and ``periodic_dump`` assert that the caller holds it.
    """

    def __init__(self, total_worker_count: int) -> None:
        self.worker_count: int = total_worker_count
        self.known_workers: Set[RemoteWorkerRecord] = set()
        self.start_time: float = time.time()
        # uuid -> processing start time.  None after a worker is acquired
        # until record_processing_began() fills it in.
        self.start_per_bundle: Dict[str, Optional[float]] = defaultdict(float)
        self.end_per_bundle: Dict[str, float] = defaultdict(float)
        self.finished_bundle_timings_per_worker: Dict[RemoteWorkerRecord, List[float]] = {}
        self.in_flight_bundles_by_worker: Dict[RemoteWorkerRecord, Set[str]] = {}
        self.bundle_details_by_uuid: Dict[str, BundleDetails] = {}
        self.finished_bundle_timings: List[float] = []
        self.last_periodic_dump: Optional[float] = None
        self.total_bundles_submitted: int = 0

        # Protects reads and modification using self.  Also used
        # as a memory fence for modifications to bundle.
        self.lock: threading.Lock = threading.Lock()

    def record_acquire_worker(self, worker: RemoteWorkerRecord, uuid: str) -> None:
        """Record that bundle ``uuid`` was assigned to ``worker`` (takes the lock)."""
        with self.lock:
            self.record_acquire_worker_already_locked(worker, uuid)

    def record_acquire_worker_already_locked(self, worker: RemoteWorkerRecord, uuid: str) -> None:
        """Like record_acquire_worker; the caller must already hold self.lock."""
        assert self.lock.locked()
        self.known_workers.add(worker)
        self.start_per_bundle[uuid] = None
        self.in_flight_bundles_by_worker.setdefault(worker, set()).add(uuid)

    def record_bundle_details(self, details: BundleDetails) -> None:
        """Index ``details`` by its uuid (takes the lock)."""
        with self.lock:
            self.record_bundle_details_already_locked(details)

    def record_bundle_details_already_locked(self, details: BundleDetails) -> None:
        """Like record_bundle_details; the caller must already hold self.lock."""
        assert self.lock.locked()
        self.bundle_details_by_uuid[details.uuid] = details

    def record_release_worker(
        self,
        worker: RemoteWorkerRecord,
        uuid: str,
        was_cancelled: bool,
    ) -> None:
        """Record that ``worker`` finished with bundle ``uuid`` (takes the lock)."""
        with self.lock:
            self.record_release_worker_already_locked(worker, uuid, was_cancelled)

    def record_release_worker_already_locked(
        self,
        worker: RemoteWorkerRecord,
        uuid: str,
        was_cancelled: bool,
    ) -> None:
        """Like record_release_worker; the caller must already hold self.lock.

        Latency statistics are only updated for bundles that were not
        cancelled, and require record_processing_began() to have run.
        """
        assert self.lock.locked()
        ts = time.time()
        self.end_per_bundle[uuid] = ts
        self.in_flight_bundles_by_worker[worker].remove(uuid)
        if not was_cancelled:
            start = self.start_per_bundle[uuid]
            assert start is not None
            bundle_latency = ts - start
            self.finished_bundle_timings_per_worker.setdefault(worker, []).append(bundle_latency)
            self.finished_bundle_timings.append(bundle_latency)

    def record_processing_began(self, uuid: str):
        """Stamp the time bundle ``uuid`` actually started processing."""
        with self.lock:
            self.start_per_bundle[uuid] = time.time()

    def total_in_flight(self) -> int:
        """Number of bundles currently in flight across all known workers."""
        assert self.lock.locked()
        return sum(len(self.in_flight_bundles_by_worker[worker]) for worker in self.known_workers)

    def total_idle(self) -> int:
        """Worker slots not occupied by an in-flight bundle."""
        assert self.lock.locked()
        return self.worker_count - self.total_in_flight()

    def __repr__(self):
        """Render a multi-line, colorized status report of the pool.

        Must be called with ``self.lock`` held.  Side effect: refreshes each
        in-flight bundle's ``slower_than_local_p95`` and
        ``slower_than_global_p95`` flags (when the needed percentiles
        exist), which backup scheduling reads.
        """
        assert self.lock.locked()
        ts = time.time()
        total_finished = len(self.finished_bundle_timings)
        total_in_flight = self.total_in_flight()
        ret = f'\n\n{underline()}Remote Executor Pool Status{reset()}: '
        qall = None
        if len(self.finished_bundle_timings) > 1:
            qall = numpy.quantile(self.finished_bundle_timings, [0.5, 0.95])
            ret += (
                f'⏱=∀p50:{qall[0]:.1f}s, ∀p95:{qall[1]:.1f}s, total={ts-self.start_time:.1f}s, '
                f'✅={total_finished}/{self.total_bundles_submitted}, '
                f'💻n={total_in_flight}/{self.worker_count}\n'
            )
        else:
            ret += (
                f'⏱={ts-self.start_time:.1f}s, '
                f'✅={total_finished}/{self.total_bundles_submitted}, '
                f'💻n={total_in_flight}/{self.worker_count}\n'
            )

        for worker in self.known_workers:
            ret += f'  {fg("lightning yellow")}{worker.machine}{reset()}: '
            timings = self.finished_bundle_timings_per_worker.get(worker, [])
            count = len(timings)
            qworker = None
            if count > 1:
                qworker = numpy.quantile(timings, [0.5, 0.95])
                ret += f' 💻p50: {qworker[0]:.1f}s, 💻p95: {qworker[1]:.1f}s\n'
            else:
                ret += '\n'
            if count > 0:
                ret += f'    ...finished {count} total bundle(s) so far\n'
            in_flight = len(self.in_flight_bundles_by_worker[worker])
            if in_flight > 0:
                ret += f'    ...{in_flight} bundles currently in flight:\n'
                for bundle_uuid in self.in_flight_bundles_by_worker[worker]:
                    # details may be None when no BundleDetails has been
                    # recorded for this uuid; every write below is guarded.
                    details = self.bundle_details_by_uuid.get(bundle_uuid, None)
                    pid = str(details.pid) if (details and details.pid != 0) else "TBD"
                    start = self.start_per_bundle[bundle_uuid]
                    if start is not None:
                        sec = ts - start
                        ret += f'       (pid={pid}): {details} for {sec:.1f}s so far '
                    else:
                        ret += f'       {details} setting up / copying data...'
                        sec = 0.0

                    if qworker is not None:
                        slow_local = bool(sec > qworker[1])
                        if slow_local:
                            ret += f'{bg("red")}>💻p95{reset()} '
                        if details is not None:
                            details.slower_than_local_p95 = slow_local

                    if qall is not None:
                        slow_global = bool(sec > qall[1])
                        if slow_global:
                            ret += f'{bg("red")}>∀p95{reset()} '
                        if details is not None:
                            details.slower_than_global_p95 = slow_global
                    ret += '\n'
        return ret

    def periodic_dump(self, total_bundles_submitted: int) -> None:
        """Print the status report at most once every 5 seconds (lock held)."""
        assert self.lock.locked()
        self.total_bundles_submitted = total_bundles_submitted
        ts = time.time()
        if self.last_periodic_dump is None or ts - self.last_periodic_dump > 5.0:
            print(self)
            self.last_periodic_dump = ts
437
438
class RemoteWorkerSelectionPolicy(ABC):
    """Strategy for choosing which remote worker runs the next bundle.

    ``register_worker_pool`` must be called before the other methods;
    implementations read the pool from ``self.workers``.
    """

    def register_worker_pool(self, workers):
        """Remember the worker records this policy chooses among."""
        self.workers = workers

    @abstractmethod
    def is_worker_available(self) -> bool:
        """Return True if some worker could be acquired right now."""

    @abstractmethod
    def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
        """Reserve and return a worker, or None if none could be acquired.

        ``machine_to_avoid`` is a placement hint; implementations decide
        how (or whether) to honor it.
        """
450
451
class WeightedRandomRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
    """Pick a free worker at random with probability proportional to
    ``count * weight``.

    Workers on ``machine_to_avoid`` are excluded when any other worker has
    a free slot; otherwise any free worker may be chosen.
    """

    @overrides
    def is_worker_available(self) -> bool:
        """Return True if any worker has a free slot."""
        return any(worker.count > 0 for worker in self.workers)

    @overrides
    def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
        """Reserve a worker (decrementing its count), or return None.

        Uses weighted sampling instead of materializing ``count * weight``
        copies of each worker; the selection distribution is unchanged.
        Workers whose ``count * weight`` is not positive are never chosen.
        """
        free = [w for w in self.workers if w.count > 0 and w.count * w.weight > 0]
        candidates = [w for w in free if w.machine != machine_to_avoid]
        if not candidates:
            logger.debug('There are no available workers that avoid %s...', machine_to_avoid)
            candidates = free

        if not candidates:
            logger.warning('There are no available workers?!')
            return None

        weights = [w.count * w.weight for w in candidates]
        worker = random.choices(candidates, weights=weights)[0]
        assert worker.count > 0
        worker.count -= 1
        logger.debug('Chose worker %s', worker)
        return worker
485
486
class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
    """Hand out free workers in rotating order.

    Each scan starts just after the previously chosen worker.  Like the
    weighted-random policy, workers on ``machine_to_avoid`` are skipped
    when another worker has a free slot, and used only as a fallback.
    """

    def __init__(self) -> None:
        # Position in self.workers where the next scan starts.
        self.index = 0

    @overrides
    def is_worker_available(self) -> bool:
        """Return True if any worker has a free slot."""
        return any(worker.count > 0 for worker in self.workers)

    @overrides
    def acquire_worker(self, machine_to_avoid: Optional[str] = None) -> Optional[RemoteWorkerRecord]:
        """Reserve the next free worker in rotation (decrementing its count).

        Returns:
            The worker, or None (after logging a warning) if none is free.
        """
        n = len(self.workers)
        # Visit each worker once, starting at self.index and wrapping.  The
        # modulo also copes with a stale index after a smaller pool was
        # registered; an empty pool yields no candidates.
        order = [(self.index + offset) % n for offset in range(n)]

        def pick(avoid: Optional[str]) -> Optional[RemoteWorkerRecord]:
            # First free worker in rotation order not on `avoid`.
            for x in order:
                worker = self.workers[x]
                if worker.count > 0 and (avoid is None or worker.machine != avoid):
                    worker.count -= 1
                    self.index = (x + 1) % n
                    logger.debug('Selected worker %s', worker)
                    return worker
            return None

        worker = pick(machine_to_avoid)
        if worker is None and machine_to_avoid is not None:
            logger.debug('There are no available workers that avoid %s...', machine_to_avoid)
            worker = pick(None)
        if worker is None:
            msg = 'Unexpectedly could not find a worker, retrying...'
            logger.warning(msg)
        return worker
518
519
520 class RemoteExecutor(BaseExecutor):
521     def __init__(
522         self,
523         workers: List[RemoteWorkerRecord],
524         policy: RemoteWorkerSelectionPolicy,
525     ) -> None:
526         super().__init__()
527         self.workers = workers
528         self.policy = policy
529         self.worker_count = 0
530         for worker in self.workers:
531             self.worker_count += worker.count
532         if self.worker_count <= 0:
533             msg = f"We need somewhere to schedule work; count was {self.worker_count}"
534             logger.critical(msg)
535             raise RemoteExecutorException(msg)
536         self.policy.register_worker_pool(self.workers)
537         self.cv = threading.Condition()
538         logger.debug(f'Creating {self.worker_count} local threads, one per remote worker.')
539         self._helper_executor = fut.ThreadPoolExecutor(
540             thread_name_prefix="remote_executor_helper",
541             max_workers=self.worker_count,
542         )
543         self.status = RemoteExecutorStatus(self.worker_count)
544         self.total_bundles_submitted = 0
545         self.backup_lock = threading.Lock()
546         self.last_backup = None
547         (
548             self.heartbeat_thread,
549             self.heartbeat_stop_event,
550         ) = self.run_periodic_heartbeat()
551         self.already_shutdown = False
552
553     @background_thread
554     def run_periodic_heartbeat(self, stop_event: threading.Event) -> None:
555         while not stop_event.is_set():
556             time.sleep(5.0)
557             logger.debug('Running periodic heartbeat code...')
558             self.heartbeat()
559         logger.debug('Periodic heartbeat thread shutting down.')
560
561     def heartbeat(self) -> None:
562         # Note: this is invoked on a background thread, not an
563         # executor thread.  Be careful what you do with it b/c it
564         # needs to get back and dump status again periodically.
565         with self.status.lock:
566             self.status.periodic_dump(self.total_bundles_submitted)
567
568             # Look for bundles to reschedule via executor.submit
569             if config.config['executors_schedule_remote_backups']:
570                 self.maybe_schedule_backup_bundles()
571
    def maybe_schedule_backup_bundles(self):
        """Possibly schedule one backup (duplicate) run of a slow bundle.

        Must be called with self.status.lock held (heartbeat() does this).
        It does nothing unless all of the following hold:
          - more than two bundles have finished;
          - more than one worker slot appears idle;
          - more than 9 seconds have passed since the last backup;
          - self.backup_lock can be taken without blocking.
        Otherwise it scores every in-flight original bundle (src_bundle is
        None and backup_bundles is a list).  The highest-scoring one with a
        nonzero score is handed to schedule_backup_for_bundle.
        """
        assert self.status.lock.locked()
        num_done = len(self.status.finished_bundle_timings)
        # Approximate free capacity: total slots minus outstanding tasks
        # (launch() adds 1 to task_count, release_worker() subtracts 1).
        num_idle_workers = self.worker_count - self.task_count
        now = time.time()
        # The non-blocking acquire comes last so backup_lock is only taken
        # when every other condition already holds; `finally` releases it.
        if (
            num_done > 2
            and num_idle_workers > 1
            and (self.last_backup is None or (now - self.last_backup > 9.0))
            and self.backup_lock.acquire(blocking=False)
        ):
            try:
                assert self.backup_lock.locked()

                bundle_to_backup = None
                best_score = None
                for (
                    worker,
                    bundle_uuids,
                ) in self.status.in_flight_bundles_by_worker.items():

                    # Prefer to schedule backups of bundles running on
                    # slower machines.  base_score = int(200 / weight), so
                    # lower-weight machines score higher; it stays 0 if no
                    # record matches.
                    # NOTE(review): a matching record with weight 0 would
                    # raise ZeroDivisionError here.
                    base_score = 0
                    for record in self.workers:
                        if worker.machine == record.machine:
                            base_score = float(record.weight)
                            base_score = 1.0 / base_score
                            base_score *= 200.0
                            base_score = int(base_score)
                            break

                    for uuid in bundle_uuids:
                        bundle = self.status.bundle_details_by_uuid.get(uuid, None)
                        # Only original bundles are candidates; backups are
                        # never backed up themselves.
                        if (
                            bundle is not None
                            and bundle.src_bundle is None
                            and bundle.backup_bundles is not None
                        ):
                            score = base_score

                            # Schedule backups of bundles running
                            # longer; especially those that are
                            # unexpectedly slow.  The p95 flags are
                            # refreshed by RemoteExecutorStatus.__repr__.
                            start_ts = self.status.start_per_bundle[uuid]
                            if start_ts is not None:
                                runtime = now - start_ts
                                score += runtime
                                logger.debug(f'score[{bundle}] => {score}  # latency boost')

                                if bundle.slower_than_local_p95:
                                    score += runtime / 2
                                    logger.debug(f'score[{bundle}] => {score}  # >worker p95')

                                if bundle.slower_than_global_p95:
                                    score += runtime / 4
                                    logger.debug(f'score[{bundle}] => {score}  # >global p95')

                            # Prefer backups of bundles that don't
                            # have backups already: x2 with none, /2 with
                            # one, /8 with two, and 0 (never chosen) with
                            # three or more.
                            backup_count = len(bundle.backup_bundles)
                            if backup_count == 0:
                                score *= 2
                            elif backup_count == 1:
                                score /= 2
                            elif backup_count == 2:
                                score /= 8
                            else:
                                score = 0
                            logger.debug(
                                f'score[{bundle}] => {score}  # {backup_count} dup backup factor'
                            )

                            if score != 0 and (best_score is None or score > best_score):
                                bundle_to_backup = bundle
                                assert bundle is not None
                                assert bundle.backup_bundles is not None
                                assert bundle.src_bundle is None
                                best_score = score

                # Note: this is all still happening on the heartbeat
                # runner thread.  That's ok because
                # schedule_backup_for_bundle uses the executor to
                # submit the bundle again which will cause it to be
                # picked up by a worker thread and allow this thread
                # to return to run future heartbeats.
                if bundle_to_backup is not None:
                    self.last_backup = now
                    logger.info(
                        f'=====> SCHEDULING BACKUP {bundle_to_backup} (score={best_score:.1f}) <====='
                    )
                    self.schedule_backup_for_bundle(bundle_to_backup)
            finally:
                self.backup_lock.release()
666
667     def is_worker_available(self) -> bool:
668         return self.policy.is_worker_available()
669
670     def acquire_worker(self, machine_to_avoid: str = None) -> Optional[RemoteWorkerRecord]:
671         return self.policy.acquire_worker(machine_to_avoid)
672
673     def find_available_worker_or_block(self, machine_to_avoid: str = None) -> RemoteWorkerRecord:
674         with self.cv:
675             while not self.is_worker_available():
676                 self.cv.wait()
677             worker = self.acquire_worker(machine_to_avoid)
678             if worker is not None:
679                 return worker
680         msg = "We should never reach this point in the code"
681         logger.critical(msg)
682         raise Exception(msg)
683
684     def release_worker(self, bundle: BundleDetails, *, was_cancelled=True) -> None:
685         worker = bundle.worker
686         assert worker is not None
687         logger.debug(f'Released worker {worker}')
688         self.status.record_release_worker(
689             worker,
690             bundle.uuid,
691             was_cancelled,
692         )
693         with self.cv:
694             worker.count += 1
695             self.cv.notify()
696         self.adjust_task_count(-1)
697
698     def check_if_cancelled(self, bundle: BundleDetails) -> bool:
699         with self.status.lock:
700             if bundle.is_cancelled.wait(timeout=0.0):
701                 logger.debug(f'Bundle {bundle.uuid} is cancelled, bail out.')
702                 bundle.was_cancelled = True
703                 return True
704         return False
705
706     def launch(self, bundle: BundleDetails, override_avoid_machine=None) -> Any:
707         """Find a worker for bundle or block until one is available."""
708         self.adjust_task_count(+1)
709         uuid = bundle.uuid
710         hostname = bundle.hostname
711         avoid_machine = override_avoid_machine
712         is_original = bundle.src_bundle is None
713
714         # Try not to schedule a backup on the same host as the original.
715         if avoid_machine is None and bundle.src_bundle is not None:
716             avoid_machine = bundle.src_bundle.machine
717         worker = None
718         while worker is None:
719             worker = self.find_available_worker_or_block(avoid_machine)
720         assert worker is not None
721
722         # Ok, found a worker.
723         bundle.worker = worker
724         machine = bundle.machine = worker.machine
725         username = bundle.username = worker.username
726         self.status.record_acquire_worker(worker, uuid)
727         logger.debug(f'{bundle}: Running bundle on {worker}...')
728
729         # Before we do any work, make sure the bundle is still viable.
730         # It may have been some time between when it was submitted and
731         # now due to lack of worker availability and someone else may
732         # have already finished it.
733         if self.check_if_cancelled(bundle):
734             try:
735                 return self.process_work_result(bundle)
736             except Exception as e:
737                 logger.warning(f'{bundle}: bundle says it\'s cancelled upfront but no results?!')
738                 self.release_worker(bundle)
739                 if is_original:
740                     # Weird.  We are the original owner of this
741                     # bundle.  For it to have been cancelled, a backup
742                     # must have already started and completed before
743                     # we even for started.  Moreover, the backup says
744                     # it is done but we can't find the results it
745                     # should have copied over.  Reschedule the whole
746                     # thing.
747                     logger.exception(e)
748                     logger.error(
749                         f'{bundle}: We are the original owner thread and yet there are '
750                         + 'no results for this bundle.  This is unexpected and bad.'
751                     )
752                     return self.emergency_retry_nasty_bundle(bundle)
753                 else:
754                     # Expected(?).  We're a backup and our bundle is
755                     # cancelled before we even got started.  Something
756                     # went bad in process_work_result (I acutually don't
757                     # see what?) but probably not worth worrying
758                     # about.  Let the original thread worry about
759                     # either finding the results or complaining about
760                     # it.
761                     return None
762
763         # Send input code / data to worker machine if it's not local.
764         if hostname not in machine:
765             try:
766                 cmd = f'{SCP} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
767                 start_ts = time.time()
768                 logger.info(f"{bundle}: Copying work to {worker} via {cmd}.")
769                 run_silently(cmd)
770                 xfer_latency = time.time() - start_ts
771                 logger.debug(f"{bundle}: Copying to {worker} took {xfer_latency:.1f}s.")
772             except Exception as e:
773                 self.release_worker(bundle)
774                 if is_original:
775                     # Weird.  We tried to copy the code to the worker and it failed...
776                     # And we're the original bundle.  We have to retry.
777                     logger.exception(e)
778                     logger.error(
779                         f"{bundle}: Failed to send instructions to the worker machine?! "
780                         + "This is not expected; we\'re the original bundle so this shouldn\'t "
781                         + "be a race condition.  Attempting an emergency retry..."
782                     )
783                     return self.emergency_retry_nasty_bundle(bundle)
784                 else:
785                     # This is actually expected; we're a backup.
786                     # There's a race condition where someone else
787                     # already finished the work and removed the source
788                     # code file before we could copy it.  No biggie.
789                     msg = f'{bundle}: Failed to send instructions to the worker machine... '
790                     msg += 'We\'re a backup and this may be caused by the original (or some '
791                     msg += 'other backup) already finishing this work.  Ignoring this.'
792                     logger.warning(msg)
793                     return None
794
795         # Kick off the work.  Note that if this fails we let
796         # wait_for_process deal with it.
797         self.status.record_processing_began(uuid)
798         cmd = (
799             f'{SSH} {bundle.username}@{bundle.machine} '
800             f'"source py38-venv/bin/activate &&'
801             f' /home/scott/lib/python_modules/remote_worker.py'
802             f' --code_file {bundle.code_file} --result_file {bundle.result_file}"'
803         )
804         logger.debug(f'{bundle}: Executing {cmd} in the background to kick off work...')
805         p = cmd_in_background(cmd, silent=True)
806         bundle.pid = p.pid
807         logger.debug(f'{bundle}: Local ssh process pid={p.pid}; remote worker is {machine}.')
808         return self.wait_for_process(p, bundle, 0)
809
810     def wait_for_process(
811         self, p: Optional[subprocess.Popen], bundle: BundleDetails, depth: int
812     ) -> Any:
813         machine = bundle.machine
814         assert p is not None
815         pid = p.pid
816         if depth > 3:
817             logger.error(
818                 f"I've gotten repeated errors waiting on this bundle; giving up on pid={pid}."
819             )
820             p.terminate()
821             self.release_worker(bundle)
822             return self.emergency_retry_nasty_bundle(bundle)
823
824         # Spin until either the ssh job we scheduled finishes the
825         # bundle or some backup worker signals that they finished it
826         # before we could.
827         while True:
828             try:
829                 p.wait(timeout=0.25)
830             except subprocess.TimeoutExpired:
831                 if self.check_if_cancelled(bundle):
832                     logger.info(f'{bundle}: looks like another worker finished bundle...')
833                     break
834             else:
835                 logger.info(f"{bundle}: pid {pid} ({machine}) is finished!")
836                 p = None
837                 break
838
839         # If we get here we believe the bundle is done; either the ssh
840         # subprocess finished (hopefully successfully) or we noticed
841         # that some other worker seems to have completed the bundle
842         # and we're bailing out.
843         try:
844             ret = self.process_work_result(bundle)
845             if ret is not None and p is not None:
846                 p.terminate()
847             return ret
848
849         # Something went wrong; e.g. we could not copy the results
850         # back, cleanup after ourselves on the remote machine, or
851         # unpickle the results we got from the remove machine.  If we
852         # still have an active ssh subprocess, keep waiting on it.
853         # Otherwise, time for an emergency reschedule.
854         except Exception as e:
855             logger.exception(e)
856             logger.error(f'{bundle}: Something unexpected just happened...')
857             if p is not None:
858                 msg = f"{bundle}: Failed to wrap up \"done\" bundle, re-waiting on active ssh."
859                 logger.warning(msg)
860                 return self.wait_for_process(p, bundle, depth + 1)
861             else:
862                 self.release_worker(bundle)
863                 return self.emergency_retry_nasty_bundle(bundle)
864
865     def process_work_result(self, bundle: BundleDetails) -> Any:
866         with self.status.lock:
867             is_original = bundle.src_bundle is None
868             was_cancelled = bundle.was_cancelled
869             username = bundle.username
870             machine = bundle.machine
871             result_file = bundle.result_file
872             code_file = bundle.code_file
873
874             # Whether original or backup, if we finished first we must
875             # fetch the results if the computation happened on a
876             # remote machine.
877             bundle.end_ts = time.time()
878             if not was_cancelled:
879                 assert bundle.machine is not None
880                 if bundle.hostname not in bundle.machine:
881                     cmd = f'{SCP} {username}@{machine}:{result_file} {result_file} 2>/dev/null'
882                     logger.info(
883                         f"{bundle}: Fetching results back from {username}@{machine} via {cmd}"
884                     )
885
886                     # If either of these throw they are handled in
887                     # wait_for_process.
888                     attempts = 0
889                     while True:
890                         try:
891                             run_silently(cmd)
892                         except Exception as e:
893                             attempts += 1
894                             if attempts >= 3:
895                                 raise (e)
896                         else:
897                             break
898
899                     run_silently(
900                         f'{SSH} {username}@{machine}' f' "/bin/rm -f {code_file} {result_file}"'
901                     )
902                     logger.debug(f'Fetching results back took {time.time() - bundle.end_ts:.1f}s.')
903                 dur = bundle.end_ts - bundle.start_ts
904                 self.histogram.add_item(dur)
905
906         # Only the original worker should unpickle the file contents
907         # though since it's the only one whose result matters.  The
908         # original is also the only job that may delete result_file
909         # from disk.  Note that the original may have been cancelled
910         # if one of the backups finished first; it still must read the
911         # result from disk.
912         if is_original:
913             logger.debug(f"{bundle}: Unpickling {result_file}.")
914             try:
915                 with open(result_file, 'rb') as rb:
916                     serialized = rb.read()
917                 result = cloudpickle.loads(serialized)
918             except Exception as e:
919                 logger.exception(e)
920                 msg = f'Failed to load {result_file}... this is bad news.'
921                 logger.critical(msg)
922                 self.release_worker(bundle)
923
924                 # Re-raise the exception; the code in wait_for_process may
925                 # decide to emergency_retry_nasty_bundle here.
926                 raise Exception(e)
927             logger.debug(f'Removing local (master) {code_file} and {result_file}.')
928             os.remove(f'{result_file}')
929             os.remove(f'{code_file}')
930
931             # Notify any backups that the original is done so they
932             # should stop ASAP.  Do this whether or not we
933             # finished first since there could be more than one
934             # backup.
935             if bundle.backup_bundles is not None:
936                 for backup in bundle.backup_bundles:
937                     logger.debug(f'{bundle}: Notifying backup {backup.uuid} that it\'s cancelled')
938                     backup.is_cancelled.set()
939
940         # This is a backup job and, by now, we have already fetched
941         # the bundle results.
942         else:
943             # Backup results don't matter, they just need to leave the
944             # result file in the right place for their originals to
945             # read/unpickle later.
946             result = None
947
948             # Tell the original to stop if we finished first.
949             if not was_cancelled:
950                 orig_bundle = bundle.src_bundle
951                 assert orig_bundle is not None
952                 logger.debug(f'{bundle}: Notifying original {orig_bundle.uuid} we beat them to it.')
953                 orig_bundle.is_cancelled.set()
954         self.release_worker(bundle, was_cancelled=was_cancelled)
955         return result
956
957     def create_original_bundle(self, pickle, fname: str):
958         from string_utils import generate_uuid
959
960         uuid = generate_uuid(omit_dashes=True)
961         code_file = f'/tmp/{uuid}.code.bin'
962         result_file = f'/tmp/{uuid}.result.bin'
963
964         logger.debug(f'Writing pickled code to {code_file}')
965         with open(f'{code_file}', 'wb') as wb:
966             wb.write(pickle)
967
968         bundle = BundleDetails(
969             pickled_code=pickle,
970             uuid=uuid,
971             fname=fname,
972             worker=None,
973             username=None,
974             machine=None,
975             hostname=platform.node(),
976             code_file=code_file,
977             result_file=result_file,
978             pid=0,
979             start_ts=time.time(),
980             end_ts=0.0,
981             slower_than_local_p95=False,
982             slower_than_global_p95=False,
983             src_bundle=None,
984             is_cancelled=threading.Event(),
985             was_cancelled=False,
986             backup_bundles=[],
987             failure_count=0,
988         )
989         self.status.record_bundle_details(bundle)
990         logger.debug(f'{bundle}: Created an original bundle')
991         return bundle
992
993     def create_backup_bundle(self, src_bundle: BundleDetails):
994         assert src_bundle.backup_bundles is not None
995         n = len(src_bundle.backup_bundles)
996         uuid = src_bundle.uuid + f'_backup#{n}'
997
998         backup_bundle = BundleDetails(
999             pickled_code=src_bundle.pickled_code,
1000             uuid=uuid,
1001             fname=src_bundle.fname,
1002             worker=None,
1003             username=None,
1004             machine=None,
1005             hostname=src_bundle.hostname,
1006             code_file=src_bundle.code_file,
1007             result_file=src_bundle.result_file,
1008             pid=0,
1009             start_ts=time.time(),
1010             end_ts=0.0,
1011             slower_than_local_p95=False,
1012             slower_than_global_p95=False,
1013             src_bundle=src_bundle,
1014             is_cancelled=threading.Event(),
1015             was_cancelled=False,
1016             backup_bundles=None,  # backup backups not allowed
1017             failure_count=0,
1018         )
1019         src_bundle.backup_bundles.append(backup_bundle)
1020         self.status.record_bundle_details_already_locked(backup_bundle)
1021         logger.debug(f'{backup_bundle}: Created a backup bundle')
1022         return backup_bundle
1023
1024     def schedule_backup_for_bundle(self, src_bundle: BundleDetails):
1025         assert self.status.lock.locked()
1026         assert src_bundle is not None
1027         backup_bundle = self.create_backup_bundle(src_bundle)
1028         logger.debug(
1029             f'{backup_bundle.uuid}/{backup_bundle.fname}: Scheduling backup for execution...'
1030         )
1031         self._helper_executor.submit(self.launch, backup_bundle)
1032
1033         # Results from backups don't matter; if they finish first
1034         # they will move the result_file to this machine and let
1035         # the original pick them up and unpickle them.
1036
1037     def emergency_retry_nasty_bundle(self, bundle: BundleDetails) -> Optional[fut.Future]:
1038         is_original = bundle.src_bundle is None
1039         bundle.worker = None
1040         avoid_last_machine = bundle.machine
1041         bundle.machine = None
1042         bundle.username = None
1043         bundle.failure_count += 1
1044         if is_original:
1045             retry_limit = 3
1046         else:
1047             retry_limit = 2
1048
1049         if bundle.failure_count > retry_limit:
1050             logger.error(
1051                 f'{bundle}: Tried this bundle too many times already ({retry_limit}x); giving up.'
1052             )
1053             if is_original:
1054                 raise RemoteExecutorException(
1055                     f'{bundle}: This bundle can\'t be completed despite several backups and retries'
1056                 )
1057             else:
1058                 logger.error(
1059                     f'{bundle}: At least it\'s only a backup; better luck with the others.'
1060                 )
1061             return None
1062         else:
1063             msg = f'>>> Emergency rescheduling {bundle} because of unexected errors (wtf?!) <<<'
1064             logger.warning(msg)
1065             warnings.warn(msg)
1066             return self.launch(bundle, avoid_last_machine)
1067
1068     @overrides
1069     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
1070         if self.already_shutdown:
1071             raise Exception('Submitted work after shutdown.')
1072         pickle = make_cloud_pickle(function, *args, **kwargs)
1073         bundle = self.create_original_bundle(pickle, function.__name__)
1074         self.total_bundles_submitted += 1
1075         return self._helper_executor.submit(self.launch, bundle)
1076
1077     @overrides
1078     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
1079         if not self.already_shutdown:
1080             logging.debug(f'Shutting down RemoteExecutor {self.title}')
1081             self.heartbeat_stop_event.set()
1082             self.heartbeat_thread.join()
1083             self._helper_executor.shutdown(wait)
1084             if not quiet:
1085                 print(self.histogram.__repr__(label_formatter='%ds'))
1086             self.already_shutdown = True
1087
1088
@singleton
class DefaultExecutors(object):
    """Lazily constructed default thread, process and remote executors.

    Each of thread_pool(), process_pool() and remote_pool() builds its
    executor on first use and returns the cached instance afterwards;
    shutdown() shuts down (quietly, waiting) whichever ones were created.
    The class is wrapped with @singleton, presumably so that one shared
    instance is used everywhere — confirm against decorator_utils.
    """

    # Candidate helper machines for the remote pool, probed in this order.
    # Each entry is (machine, weight, count); all are used as user 'scott'.
    _REMOTE_HELPERS = (
        ('cheetah.house', 30, 6),
        ('meerkat.cabin', 12, 2),
        ('wannabe.house', 25, 10),
        ('puma.cabin', 30, 6),
        ('backup.house', 8, 2),
    )

    def __init__(self):
        self.thread_executor: Optional[ThreadExecutor] = None
        self.process_executor: Optional[ProcessExecutor] = None
        self.remote_executor: Optional[RemoteExecutor] = None

    def ping(self, host) -> bool:
        """Return True iff one ping to host succeeds within ~1 second."""
        logger.debug(f'RUN> ping -c 1 {host}')
        try:
            x = cmd_with_timeout(f'ping -c 1 {host} >/dev/null 2>/dev/null', timeout_seconds=1.0)
            return x == 0
        except Exception:
            return False

    def thread_pool(self) -> ThreadExecutor:
        """Return the shared ThreadExecutor, creating it on first use."""
        if self.thread_executor is None:
            self.thread_executor = ThreadExecutor()
        return self.thread_executor

    def process_pool(self) -> ProcessExecutor:
        """Return the shared ProcessExecutor, creating it on first use."""
        if self.process_executor is None:
            self.process_executor = ProcessExecutor()
        return self.process_executor

    def remote_pool(self) -> RemoteExecutor:
        """Return the shared RemoteExecutor, creating it on first use.

        On creation, every machine in _REMOTE_HELPERS that answers a ping
        is added to the worker pool.  If this machine is one of them, its
        worker count is reduced to 1.  Workers are then chosen by a
        WeightedRandomRemoteWorkerSelectionPolicy.
        """
        if self.remote_executor is None:
            logger.info('Looking for some helper machines...')
            pool: List[RemoteWorkerRecord] = []
            for machine, weight, count in self._REMOTE_HELPERS:
                if self.ping(machine):
                    logger.info(f'Found {machine}')
                    pool.append(
                        RemoteWorkerRecord(
                            username='scott',
                            machine=machine,
                            weight=weight,
                            count=count,
                        ),
                    )

            # The controller machine has a lot to do; go easy on it.
            # Recognize it the same way RemoteExecutor decides a bundle is
            # local (hostname is a substring of the worker's machine name).
            # That way both short ('cheetah') and fully qualified
            # ('cheetah.house') node names match.  The empty-string guard
            # avoids matching every machine when the node name is unknown.
            local_node = platform.node()
            for record in pool:
                if local_node and local_node in record.machine and record.count > 1:
                    logger.info(f'Reducing workload for {record.machine}.')
                    record.count = 1

            policy = WeightedRandomRemoteWorkerSelectionPolicy()
            policy.register_worker_pool(pool)
            self.remote_executor = RemoteExecutor(pool, policy)
        return self.remote_executor

    def shutdown(self) -> None:
        """Shut down (waiting, quietly) and forget every executor created so far."""
        if self.thread_executor is not None:
            self.thread_executor.shutdown(wait=True, quiet=True)
            self.thread_executor = None
        if self.process_executor is not None:
            self.process_executor.shutdown(wait=True, quiet=True)
            self.process_executor = None
        if self.remote_executor is not None:
            self.remote_executor.shutdown(wait=True, quiet=True)
            self.remote_executor = None