Loosen backup policy and clean up code a little.
[python_utils.git] / executors.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3
4 """Defines three executors: a thread executor for doing work using a
5 threadpool, a process executor for doing work in other processes on
6 the same machine and a remote executor for farming out work to other
7 machines.
8
9 Also defines DefaultExecutors which is a container for references to
10 global executors / worker pools with automatic shutdown semantics."""
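
# A minimal usage sketch (illustrative only; it assumes this module is
# importable as "executors" and that the @singleton-decorated DefaultExecutors
# below hands back the shared instance each time it is called):
#
#     from executors import DefaultExecutors
#
#     def square(x: int) -> int:
#         return x * x
#
#     pool = DefaultExecutors().thread_pool()
#     future = pool.submit(square, 7)
#     assert future.result() == 49
#     DefaultExecutors().shutdown()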
11
12 from __future__ import annotations
13 import concurrent.futures as fut
14 import logging
15 import os
16 import platform
17 import random
18 import subprocess
19 import threading
20 import time
21 import warnings
22 from abc import ABC, abstractmethod
23 from collections import defaultdict
24 from dataclasses import dataclass
25 from typing import Any, Callable, Dict, List, Optional, Set
26
27 import cloudpickle  # type: ignore
28 import numpy
29 from overrides import overrides
30
31 import argparse_utils
32 import config
33 import histogram as hist
34 import string_utils
35 from ansi import bg, fg, reset, underline
36 from decorator_utils import singleton
37 from exec_utils import cmd_in_background, cmd_with_timeout, run_silently
38 from thread_utils import background_thread
39
40 logger = logging.getLogger(__name__)
41
42 parser = config.add_commandline_args(
43     f"Executors ({__file__})", "Args related to processing executors."
44 )
45 parser.add_argument(
46     '--executors_threadpool_size',
47     type=int,
48     metavar='#THREADS',
49     help='Number of threads in the default threadpool, leave unset for default',
50     default=None,
51 )
52 parser.add_argument(
53     '--executors_processpool_size',
54     type=int,
55     metavar='#PROCESSES',
56     help='Number of processes in the default processpool, leave unset for default',
57     default=None,
58 )
59 parser.add_argument(
60     '--executors_schedule_remote_backups',
61     default=True,
62     action=argparse_utils.ActionNoYes,
63     help='Should we schedule duplicative backup work if a remote bundle is slow',
64 )
65 parser.add_argument(
66     '--executors_max_bundle_failures',
67     type=int,
68     default=3,
69     metavar='#FAILURES',
70     help='Maximum number of failures before giving up on a bundle',
71 )
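
# For example, a program that pulls in this module might be invoked with flags
# like the following (flag names as defined above; the values are illustrative
# and how they are parsed/consumed is up to the config module):
#
#     ./some_program.py --executors_threadpool_size 16 \
#         --executors_max_bundle_failures 5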
72
73 SSH = '/usr/bin/ssh -oForwardX11=no'
74 SCP = '/usr/bin/scp -C'
75
76
77 def make_cloud_pickle(fun, *args, **kwargs):
78     logger.debug("Making cloudpickled bundle at %s", fun.__name__)
79     return cloudpickle.dumps((fun, args, kwargs))
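
# A sketch of the round trip (assumes "import math" wherever this runs):
#
#     blob = make_cloud_pickle(math.pow, 2, 10)
#     fun, args, kwargs = cloudpickle.loads(blob)
#     assert fun(*args, **kwargs) == 1024.0
#
# ProcessExecutor.run_cloud_pickle below performs the loads() half of this
# round trip on the receiving side.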
80
81
82 class BaseExecutor(ABC):
83     """The base executor interface definition."""
84
85     def __init__(self, *, title=''):
86         self.title = title
87         self.histogram = hist.SimpleHistogram(
88             hist.SimpleHistogram.n_evenly_spaced_buckets(int(0), int(500), 50)
89         )
90         self.task_count = 0
91
92     @abstractmethod
93     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
94         pass
95
96     @abstractmethod
97     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
98         pass
99
100     def shutdown_if_idle(self, *, quiet: bool = False) -> bool:
101         """If the executor is idle (i.e. there are no pending or active
102         tasks), shut it down and return True.  Return False otherwise.
103         Note: this should only be called by the launcher
104         process.
105
106         """
107         if self.task_count == 0:
108             self.shutdown(wait=True, quiet=quiet)
109             return True
110         return False
111
112     def adjust_task_count(self, delta: int) -> None:
113         """Change the task count.  Note: do not call this method from a
114         worker; it should only be called by the launcher process /
115         thread / machine.
116
117         """
118         self.task_count += delta
119         logger.debug('Adjusted task count by %d to %d.', delta, self.task_count)
120
121     def get_task_count(self) -> int:
122         """Return the current task count.  Note: do not call this method from a
123         worker; it should only be called by the launcher process /
124         thread / machine.
125
126         """
127         return self.task_count
128
129
130 class ThreadExecutor(BaseExecutor):
131     """A threadpool executor instance."""
132
133     def __init__(self, max_workers: Optional[int] = None):
134         super().__init__()
135         workers = None
136         if max_workers is not None:
137             workers = max_workers
138         elif 'executors_threadpool_size' in config.config:
139             workers = config.config['executors_threadpool_size']
140         logger.debug('Creating threadpool executor with %s workers', workers)
141         self._thread_pool_executor = fut.ThreadPoolExecutor(
142             max_workers=workers, thread_name_prefix="thread_executor_helper"
143         )
144         self.already_shutdown = False
145
146     # This is run on a different thread; do not adjust task count here.
147     @staticmethod
148     def run_local_bundle(fun, *args, **kwargs):
149         logger.debug("Running local bundle at %s", fun.__name__)
150         result = fun(*args, **kwargs)
151         return result
152
153     @overrides
154     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
155         if self.already_shutdown:
156             raise Exception('Submitted work after shutdown.')
157         self.adjust_task_count(+1)
158         newargs = [function, *args]
162         start = time.time()
163         result = self._thread_pool_executor.submit(
164             ThreadExecutor.run_local_bundle, *newargs, **kwargs
165         )
166         result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
167         result.add_done_callback(lambda _: self.adjust_task_count(-1))
168         return result
169
170     @overrides
171     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
172         if not self.already_shutdown:
173             logger.debug('Shutting down threadpool executor %s', self.title)
174             self._thread_pool_executor.shutdown(wait)
175             if not quiet:
176                 print(self.histogram.__repr__(label_formatter='%ds'))
177             self.already_shutdown = True
178
179
180 class ProcessExecutor(BaseExecutor):
181     """A processpool executor."""
182
183     def __init__(self, max_workers=None):
184         super().__init__()
185         workers = None
186         if max_workers is not None:
187             workers = max_workers
188         elif 'executors_processpool_size' in config.config:
189             workers = config.config['executors_processpool_size']
190         logger.debug('Creating processpool executor with %s workers.', workers)
191         self._process_executor = fut.ProcessPoolExecutor(
192             max_workers=workers,
193         )
194         self.already_shutdown = False
195
196     # This is run in another process; do not adjust task count here.
197     @staticmethod
198     def run_cloud_pickle(pickle):
199         fun, args, kwargs = cloudpickle.loads(pickle)
200         logger.debug("Running pickled bundle at %s", fun.__name__)
201         result = fun(*args, **kwargs)
202         return result
203
204     @overrides
205     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
206         if self.already_shutdown:
207             raise Exception('Submitted work after shutdown.')
208         start = time.time()
209         self.adjust_task_count(+1)
210         pickle = make_cloud_pickle(function, *args, **kwargs)
211         result = self._process_executor.submit(ProcessExecutor.run_cloud_pickle, pickle)
212         result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
213         result.add_done_callback(lambda _: self.adjust_task_count(-1))
214         return result
215
216     @overrides
217     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
218         if not self.already_shutdown:
219             logger.debug('Shutting down processpool executor %s', self.title)
220             self._process_executor.shutdown(wait)
221             if not quiet:
222                 print(self.histogram.__repr__(label_formatter='%ds'))
223             self.already_shutdown = True
224
225     def __getstate__(self):
226         state = self.__dict__.copy()
227         state['_process_executor'] = None
228         return state
229
230
231 class RemoteExecutorException(Exception):
232     """Thrown when a bundle cannot be executed despite several retries."""
233
234     pass
235
236
237 @dataclass
238 class RemoteWorkerRecord:
239     """A record of info about a remote worker."""
240
241     username: str
242     machine: str
243     weight: int
244     count: int
245
246     def __hash__(self):
247         return hash((self.username, self.machine))
248
249     def __repr__(self):
250         return f'{self.username}@{self.machine}'
251
252
253 @dataclass
254 class BundleDetails:
255     """All info necessary to define some unit of work that needs to be
256     done, where it is being run, its state, whether it is an original
257     bundle or a backup bundle, how many times it has failed, etc...
258
259     """
260
261     pickled_code: bytes
262     uuid: str
263     fname: str
264     worker: Optional[RemoteWorkerRecord]
265     username: Optional[str]
266     machine: Optional[str]
267     hostname: str
268     code_file: str
269     result_file: str
270     pid: int
271     start_ts: float
272     end_ts: float
273     slower_than_local_p95: bool
274     slower_than_global_p95: bool
275     src_bundle: Optional[BundleDetails]
276     is_cancelled: threading.Event
277     was_cancelled: bool
278     backup_bundles: Optional[List[BundleDetails]]
279     failure_count: int
280
281     def __repr__(self):
282         uuid = self.uuid
283         if uuid[-9:-2] == '_backup':
284             uuid = uuid[:-9]
285             suffix = f'{uuid[-6:]}_b{self.uuid[-1:]}'
286         else:
287             suffix = uuid[-6:]
288
289         colorz = [
290             fg('violet red'),
291             fg('red'),
292             fg('orange'),
293             fg('peach orange'),
294             fg('yellow'),
295             fg('marigold yellow'),
296             fg('green yellow'),
297             fg('tea green'),
298             fg('cornflower blue'),
299             fg('turquoise blue'),
300             fg('tropical blue'),
301             fg('lavender purple'),
302             fg('medium purple'),
303         ]
304         c = colorz[int(uuid[-2:], 16) % len(colorz)]
305         fname = self.fname if self.fname is not None else 'nofname'
306         machine = self.machine if self.machine is not None else 'nomachine'
307         return f'{c}{suffix}/{fname}/{machine}{reset()}'
308
309
310 class RemoteExecutorStatus:
311     """A status 'scoreboard' for a remote executor."""
312
313     def __init__(self, total_worker_count: int) -> None:
314         self.worker_count: int = total_worker_count
315         self.known_workers: Set[RemoteWorkerRecord] = set()
316         self.start_time: float = time.time()
317         self.start_per_bundle: Dict[str, Optional[float]] = defaultdict(float)
318         self.end_per_bundle: Dict[str, float] = defaultdict(float)
319         self.finished_bundle_timings_per_worker: Dict[RemoteWorkerRecord, List[float]] = {}
320         self.in_flight_bundles_by_worker: Dict[RemoteWorkerRecord, Set[str]] = {}
321         self.bundle_details_by_uuid: Dict[str, BundleDetails] = {}
322         self.finished_bundle_timings: List[float] = []
323         self.last_periodic_dump: Optional[float] = None
324         self.total_bundles_submitted: int = 0
325
326         # Protects reads and modification using self.  Also used
327         # as a memory fence for modifications to bundle.
328         self.lock: threading.Lock = threading.Lock()
329
330     def record_acquire_worker(self, worker: RemoteWorkerRecord, uuid: str) -> None:
331         with self.lock:
332             self.record_acquire_worker_already_locked(worker, uuid)
333
334     def record_acquire_worker_already_locked(self, worker: RemoteWorkerRecord, uuid: str) -> None:
335         assert self.lock.locked()
336         self.known_workers.add(worker)
337         self.start_per_bundle[uuid] = None
338         x = self.in_flight_bundles_by_worker.get(worker, set())
339         x.add(uuid)
340         self.in_flight_bundles_by_worker[worker] = x
341
342     def record_bundle_details(self, details: BundleDetails) -> None:
343         with self.lock:
344             self.record_bundle_details_already_locked(details)
345
346     def record_bundle_details_already_locked(self, details: BundleDetails) -> None:
347         assert self.lock.locked()
348         self.bundle_details_by_uuid[details.uuid] = details
349
350     def record_release_worker(
351         self,
352         worker: RemoteWorkerRecord,
353         uuid: str,
354         was_cancelled: bool,
355     ) -> None:
356         with self.lock:
357             self.record_release_worker_already_locked(worker, uuid, was_cancelled)
358
359     def record_release_worker_already_locked(
360         self,
361         worker: RemoteWorkerRecord,
362         uuid: str,
363         was_cancelled: bool,
364     ) -> None:
365         assert self.lock.locked()
366         ts = time.time()
367         self.end_per_bundle[uuid] = ts
368         self.in_flight_bundles_by_worker[worker].remove(uuid)
369         if not was_cancelled:
370             start = self.start_per_bundle[uuid]
371             assert start is not None
372             bundle_latency = ts - start
373             x = self.finished_bundle_timings_per_worker.get(worker, [])
374             x.append(bundle_latency)
375             self.finished_bundle_timings_per_worker[worker] = x
376             self.finished_bundle_timings.append(bundle_latency)
377
378     def record_processing_began(self, uuid: str):
379         with self.lock:
380             self.start_per_bundle[uuid] = time.time()
381
382     def total_in_flight(self) -> int:
383         assert self.lock.locked()
384         total_in_flight = 0
385         for worker in self.known_workers:
386             total_in_flight += len(self.in_flight_bundles_by_worker[worker])
387         return total_in_flight
388
389     def total_idle(self) -> int:
390         assert self.lock.locked()
391         return self.worker_count - self.total_in_flight()
392
393     def __repr__(self):
394         assert self.lock.locked()
395         ts = time.time()
396         total_finished = len(self.finished_bundle_timings)
397         total_in_flight = self.total_in_flight()
398         ret = f'\n\n{underline()}Remote Executor Pool Status{reset()}: '
399         qall = None
400         if len(self.finished_bundle_timings) > 1:
401             qall = numpy.quantile(self.finished_bundle_timings, [0.5, 0.95])
402             ret += (
403                 f'⏱=∀p50:{qall[0]:.1f}s, ∀p95:{qall[1]:.1f}s, total={ts-self.start_time:.1f}s, '
404                 f'✅={total_finished}/{self.total_bundles_submitted}, '
405                 f'💻n={total_in_flight}/{self.worker_count}\n'
406             )
407         else:
408             ret += (
409                 f'⏱={ts-self.start_time:.1f}s, '
410                 f'✅={total_finished}/{self.total_bundles_submitted}, '
411                 f'💻n={total_in_flight}/{self.worker_count}\n'
412             )
413
414         for worker in self.known_workers:
415             ret += f'  {fg("lightning yellow")}{worker.machine}{reset()}: '
416             timings = self.finished_bundle_timings_per_worker.get(worker, [])
417             count = len(timings)
418             qworker = None
419             if count > 1:
420                 qworker = numpy.quantile(timings, [0.5, 0.95])
421                 ret += f' 💻p50: {qworker[0]:.1f}s, 💻p95: {qworker[1]:.1f}s\n'
422             else:
423                 ret += '\n'
424             if count > 0:
425                 ret += f'    ...finished {count} total bundle(s) so far\n'
426             in_flight = len(self.in_flight_bundles_by_worker[worker])
427             if in_flight > 0:
428                 ret += f'    ...{in_flight} bundles currently in flight:\n'
429                 for bundle_uuid in self.in_flight_bundles_by_worker[worker]:
430                     details = self.bundle_details_by_uuid.get(bundle_uuid, None)
431                     pid = str(details.pid) if (details and details.pid != 0) else "TBD"
432                     if self.start_per_bundle[bundle_uuid] is not None:
433                         sec = ts - self.start_per_bundle[bundle_uuid]
434                         ret += f'       (pid={pid}): {details} for {sec:.1f}s so far '
435                     else:
436                         ret += f'       {details} setting up / copying data...'
437                         sec = 0.0
438
439                     if qworker is not None:
440                         if sec > qworker[1]:
441                             ret += f'{bg("red")}>💻p95{reset()} '
442                             if details is not None:
443                                 details.slower_than_local_p95 = True
444                         else:
445                             if details is not None:
446                                 details.slower_than_local_p95 = False
447
448                     if qall is not None:
449                         if sec > qall[1]:
450                             ret += f'{bg("red")}>∀p95{reset()} '
451                             if details is not None:
452                                 details.slower_than_global_p95 = True
453                         elif details is not None:
454                             details.slower_than_global_p95 = False
455                     ret += '\n'
456         return ret
457
458     def periodic_dump(self, total_bundles_submitted: int) -> None:
459         assert self.lock.locked()
460         self.total_bundles_submitted = total_bundles_submitted
461         ts = time.time()
462         if self.last_periodic_dump is None or ts - self.last_periodic_dump > 5.0:
463             print(self)
464             self.last_periodic_dump = ts
465
466
467 class RemoteWorkerSelectionPolicy(ABC):
468     """Base class for remote worker selection policies."""
469
470     def __init__(self):
471         self.workers: Optional[List[RemoteWorkerRecord]] = None
472
473     def register_worker_pool(self, workers: List[RemoteWorkerRecord]):
474         self.workers = workers
475
476     @abstractmethod
477     def is_worker_available(self) -> bool:
478         pass
479
480     @abstractmethod
481     def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
482         pass
483
484
485 class WeightedRandomRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
486     """A remote worker selector that uses weighted RNG."""
487
488     @overrides
489     def is_worker_available(self) -> bool:
490         if self.workers:
491             for worker in self.workers:
492                 if worker.count > 0:
493                     return True
494         return False
495
496     @overrides
497     def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
498         grabbag = []
499         if self.workers:
500             for worker in self.workers:
501                 if worker.machine != machine_to_avoid:
502                     if worker.count > 0:
503                         for _ in range(worker.count * worker.weight):
504                             grabbag.append(worker)
505
506         if len(grabbag) == 0:
507             logger.debug('There are no available workers that avoid %s', machine_to_avoid)
508             if self.workers:
509                 for worker in self.workers:
510                     if worker.count > 0:
511                         for _ in range(worker.count * worker.weight):
512                             grabbag.append(worker)
513
514         if len(grabbag) == 0:
515             logger.warning('There are no available workers?!')
516             return None
517
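        # For illustration (assumed numbers, not anything from real config): a
        # worker with count=2 and weight=24 contributes 48 entries to the
        # grabbag while one with count=2 and weight=12 contributes 24, so the
        # uniform sample below picks the heavier worker about 2/3 of the time.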
518         worker = random.sample(grabbag, 1)[0]
519         assert worker.count > 0
520         worker.count -= 1
521         logger.debug('Selected worker %s', worker)
522         return worker
523
524
525 class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
526     """A remote worker selector that just round robins."""
527
528     def __init__(self) -> None:
529         super().__init__()
530         self.index = 0
531
532     @overrides
533     def is_worker_available(self) -> bool:
534         if self.workers:
535             for worker in self.workers:
536                 if worker.count > 0:
537                     return True
538         return False
539
540     @overrides
541     def acquire_worker(self, machine_to_avoid: Optional[str] = None) -> Optional[RemoteWorkerRecord]:
542         if self.workers:
543             x = self.index
544             while True:
545                 worker = self.workers[x]
546                 if worker.count > 0:
547                     worker.count -= 1
548                     x += 1
549                     if x >= len(self.workers):
550                         x = 0
551                     self.index = x
552                     logger.debug('Selected worker %s', worker)
553                     return worker
554                 x += 1
555                 if x >= len(self.workers):
556                     x = 0
557                 if x == self.index:
558                     logger.warning('Unexpectedly could not find a worker, retrying...')
559                     return None
560         return None
561
562
563 class RemoteExecutor(BaseExecutor):
564     """A remote work executor."""
565
566     def __init__(
567         self,
568         workers: List[RemoteWorkerRecord],
569         policy: RemoteWorkerSelectionPolicy,
570     ) -> None:
571         super().__init__()
572         self.workers = workers
573         self.policy = policy
574         self.worker_count = 0
575         for worker in self.workers:
576             self.worker_count += worker.count
577         if self.worker_count <= 0:
578             msg = f"We need somewhere to schedule work; count was {self.worker_count}"
579             logger.critical(msg)
580             raise RemoteExecutorException(msg)
581         self.policy.register_worker_pool(self.workers)
582         self.cv = threading.Condition()
583         logger.debug('Creating %d local threads, one per remote worker.', self.worker_count)
584         self._helper_executor = fut.ThreadPoolExecutor(
585             thread_name_prefix="remote_executor_helper",
586             max_workers=self.worker_count,
587         )
588         self.status = RemoteExecutorStatus(self.worker_count)
589         self.total_bundles_submitted = 0
590         self.backup_lock = threading.Lock()
591         self.last_backup = None
592         (
593             self.heartbeat_thread,
594             self.heartbeat_stop_event,
595         ) = self.run_periodic_heartbeat()
596         self.already_shutdown = False
597
598     @background_thread
599     def run_periodic_heartbeat(self, stop_event: threading.Event) -> None:
600         while not stop_event.is_set():
601             time.sleep(5.0)
602             logger.debug('Running periodic heartbeat code...')
603             self.heartbeat()
604         logger.debug('Periodic heartbeat thread shutting down.')
605
606     def heartbeat(self) -> None:
607         # Note: this is invoked on a background thread, not an
608         # executor thread.  Be careful what you do with it b/c it
609         # needs to get back and dump status again periodically.
610         with self.status.lock:
611             self.status.periodic_dump(self.total_bundles_submitted)
612
613             # Look for bundles to reschedule via executor.submit
614             if config.config['executors_schedule_remote_backups']:
615                 self.maybe_schedule_backup_bundles()
616
617     def maybe_schedule_backup_bundles(self):
618         assert self.status.lock.locked()
619         num_done = len(self.status.finished_bundle_timings)
620         num_idle_workers = self.worker_count - self.task_count
621         now = time.time()
622         if (
623             num_done >= 2
624             and num_idle_workers > 0
625             and (self.last_backup is None or (now - self.last_backup > 9.0))
626             and self.backup_lock.acquire(blocking=False)
627         ):
628             try:
629                 assert self.backup_lock.locked()
630
631                 bundle_to_backup = None
632                 best_score = None
633                 for (
634                     worker,
635                     bundle_uuids,
636                 ) in self.status.in_flight_bundles_by_worker.items():
637
638                     # Prefer to schedule backups of bundles running on
639                     # slower machines.
640                     base_score = 0
641                     for record in self.workers:
642                         if worker.machine == record.machine:
643                             base_score = float(record.weight)
644                             base_score = 1.0 / base_score
645                             base_score *= 200.0
646                             base_score = int(base_score)
647                             break
648
649                     for uuid in bundle_uuids:
650                         bundle = self.status.bundle_details_by_uuid.get(uuid, None)
651                         if (
652                             bundle is not None
653                             and bundle.src_bundle is None
654                             and bundle.backup_bundles is not None
655                         ):
656                             score = base_score
657
658                             # Schedule backups of bundles running
659                             # longer; especially those that are
660                             # unexpectedly slow.
661                             start_ts = self.status.start_per_bundle[uuid]
662                             if start_ts is not None:
663                                 runtime = now - start_ts
664                                 score += runtime
665                                 logger.debug('score[%s] => %.1f  # latency boost', bundle, score)
666
667                                 if bundle.slower_than_local_p95:
668                                     score += runtime / 2
669                                     logger.debug('score[%s] => %.1f  # >worker p95', bundle, score)
670
671                                 if bundle.slower_than_global_p95:
672                                     score += runtime / 4
673                                     logger.debug('score[%s] => %.1f  # >global p95', bundle, score)
674
675                             # Prefer backups of bundles that don't
676                             # have backups already.
677                             backup_count = len(bundle.backup_bundles)
678                             if backup_count == 0:
679                                 score *= 2
680                             elif backup_count == 1:
681                                 score /= 2
682                             elif backup_count == 2:
683                                 score /= 8
684                             else:
685                                 score = 0
686                             logger.debug(
687                                 'score[%s] => %.1f  # %d dup backup factor',
688                                 bundle,
689                                 score, backup_count,
690                             )
691
692                             if score != 0 and (best_score is None or score > best_score):
693                                 bundle_to_backup = bundle
694                                 assert bundle is not None
695                                 assert bundle.backup_bundles is not None
696                                 assert bundle.src_bundle is None
697                                 best_score = score
698
699                 # Note: this is all still happening on the heartbeat
700                 # runner thread.  That's ok because
701                 # schedule_backup_for_bundle uses the executor to
702                 # submit the bundle again which will cause it to be
703                 # picked up by a worker thread and allow this thread
704                 # to return to run future heartbeats.
705                 if bundle_to_backup is not None:
706                     self.last_backup = now
707                     logger.info(
708                         '=====> SCHEDULING BACKUP %s (score=%.1f) <=====',
709                         bundle_to_backup,
710                         best_score,
711                     )
712                     self.schedule_backup_for_bundle(bundle_to_backup)
713             finally:
714                 self.backup_lock.release()
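
    # Worked example of the scoring above (illustrative numbers, not from a
    # real run): a bundle on a weight-2 worker gets base_score = int(200 / 2)
    # = 100.  After 30s of runtime its score is 130; if it is also slower than
    # its worker's p95 it gains another runtime/2 = 15, i.e. 145.  With no
    # existing backups that score doubles to 290; with one backup it would be
    # halved instead, and with three or more it drops to zero.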
715
716     def is_worker_available(self) -> bool:
717         return self.policy.is_worker_available()
718
719     def acquire_worker(self, machine_to_avoid: Optional[str] = None) -> Optional[RemoteWorkerRecord]:
720         return self.policy.acquire_worker(machine_to_avoid)
721
722     def find_available_worker_or_block(self, machine_to_avoid: Optional[str] = None) -> RemoteWorkerRecord:
723         with self.cv:
724             while not self.is_worker_available():
725                 self.cv.wait()
726             worker = self.acquire_worker(machine_to_avoid)
727             if worker is not None:
728                 return worker
729         msg = "We should never reach this point in the code"
730         logger.critical(msg)
731         raise Exception(msg)
732
733     def release_worker(self, bundle: BundleDetails, *, was_cancelled=True) -> None:
734         worker = bundle.worker
735         assert worker is not None
736         logger.debug('Released worker %s', worker)
737         self.status.record_release_worker(
738             worker,
739             bundle.uuid,
740             was_cancelled,
741         )
742         with self.cv:
743             worker.count += 1
744             self.cv.notify()
745         self.adjust_task_count(-1)
746
747     def check_if_cancelled(self, bundle: BundleDetails) -> bool:
748         with self.status.lock:
749             if bundle.is_cancelled.wait(timeout=0.0):
750                 logger.debug('Bundle %s is cancelled, bail out.', bundle.uuid)
751                 bundle.was_cancelled = True
752                 return True
753         return False
754
755     def launch(self, bundle: BundleDetails, override_avoid_machine=None) -> Any:
756         """Find a worker for bundle or block until one is available."""
757
758         self.adjust_task_count(+1)
759         uuid = bundle.uuid
760         hostname = bundle.hostname
761         avoid_machine = override_avoid_machine
762         is_original = bundle.src_bundle is None
763
764         # Try not to schedule a backup on the same host as the original.
765         if avoid_machine is None and bundle.src_bundle is not None:
766             avoid_machine = bundle.src_bundle.machine
767         worker = None
768         while worker is None:
769             worker = self.find_available_worker_or_block(avoid_machine)
770         assert worker is not None
771
772         # Ok, found a worker.
773         bundle.worker = worker
774         machine = bundle.machine = worker.machine
775         username = bundle.username = worker.username
776         self.status.record_acquire_worker(worker, uuid)
777         logger.debug('%s: Running bundle on %s...', bundle, worker)
778
779         # Before we do any work, make sure the bundle is still viable.
780         # It may have been some time between when it was submitted and
781         # now due to lack of worker availability and someone else may
782         # have already finished it.
783         if self.check_if_cancelled(bundle):
784             try:
785                 return self.process_work_result(bundle)
786             except Exception as e:
787                 logger.warning('%s: bundle says it\'s cancelled upfront but no results?!', bundle)
788                 self.release_worker(bundle)
789                 if is_original:
790                     # Weird.  We are the original owner of this
791                     # bundle.  For it to have been cancelled, a backup
792                     # must have already started and completed before
793                     # we even got started.  Moreover, the backup says
794                     # it is done but we can't find the results it
795                     # should have copied over.  Reschedule the whole
796                     # thing.
797                     logger.exception(e)
798                     logger.error(
799                         '%s: We are the original owner thread and yet there are '
800                         'no results for this bundle.  This is unexpected and bad.',
801                         bundle,
802                     )
803                     return self.emergency_retry_nasty_bundle(bundle)
804                 else:
805                     # We're a backup and our bundle is cancelled
806                     # before we even got started.  Do nothing and let
807                     # the original bundle's thread worry about either
808                     # finding the results or complaining about it.
809                     return None
810
811         # Send input code / data to worker machine if it's not local.
812         if hostname not in machine:
813             try:
814                 cmd = f'{SCP} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
815                 start_ts = time.time()
816                 logger.info("%s: Copying work to %s via %s.", bundle, worker, cmd)
817                 run_silently(cmd)
818                 xfer_latency = time.time() - start_ts
819                 logger.debug("%s: Copying to %s took %.1fs.", bundle, worker, xfer_latency)
820             except Exception as e:
821                 self.release_worker(bundle)
822                 if is_original:
823                     # Weird.  We tried to copy the code to the worker
824                     # and it failed...  And we're the original bundle.
825                     # We have to retry.
826                     logger.exception(e)
827                     logger.error(
828                         "%s: Failed to send instructions to the worker machine?! "
829                         "This is not expected; we\'re the original bundle so this shouldn\'t "
830                         "be a race condition.  Attempting an emergency retry...",
831                         bundle,
832                     )
833                     return self.emergency_retry_nasty_bundle(bundle)
834                 else:
835                     # This is actually expected; we're a backup.
836                     # There's a race condition where someone else
837                     # already finished the work and removed the source
838                     # code_file before we could copy it.  Ignore.
839                     logger.warning(
840                         '%s: Failed to send instructions to the worker machine... '
841                         'We\'re a backup and this may be caused by the original (or '
842                         'some other backup) already finishing this work.  Ignoring.',
843                         bundle,
844                     )
845                     return None
846
847         # Kick off the work.  Note that if this fails we let
848         # wait_for_process deal with it.
849         self.status.record_processing_began(uuid)
850         cmd = (
851             f'{SSH} {bundle.username}@{bundle.machine} '
852             f'"source py38-venv/bin/activate &&'
853             f' /home/scott/lib/python_modules/remote_worker.py'
854             f' --code_file {bundle.code_file} --result_file {bundle.result_file}"'
855         )
856         logger.debug('%s: Executing %s in the background to kick off work...', bundle, cmd)
857         p = cmd_in_background(cmd, silent=True)
858         bundle.pid = p.pid
859         logger.debug('%s: Local ssh process pid=%d; remote worker is %s.', bundle, p.pid, machine)
860         return self.wait_for_process(p, bundle, 0)
861
862     def wait_for_process(
863         self, p: Optional[subprocess.Popen], bundle: BundleDetails, depth: int
864     ) -> Any:
865         machine = bundle.machine
866         assert p is not None
867         pid = p.pid
868         if depth > 3:
869             logger.error(
870                 "I've gotten repeated errors waiting on this bundle; giving up on pid=%d", pid
871             )
872             p.terminate()
873             self.release_worker(bundle)
874             return self.emergency_retry_nasty_bundle(bundle)
875
876         # Spin until either the ssh job we scheduled finishes the
877         # bundle or some backup worker signals that they finished it
878         # before we could.
879         while True:
880             try:
881                 p.wait(timeout=0.25)
882             except subprocess.TimeoutExpired:
883                 if self.check_if_cancelled(bundle):
884                     logger.info('%s: looks like another worker finished bundle...', bundle)
885                     break
886             else:
887                 logger.info("%s: pid %d (%s) is finished!", bundle, pid, machine)
888                 p = None
889                 break
890
891         # If we get here we believe the bundle is done; either the ssh
892         # subprocess finished (hopefully successfully) or we noticed
893         # that some other worker seems to have completed the bundle
894         # and we're bailing out.
895         try:
896             ret = self.process_work_result(bundle)
897             if ret is not None and p is not None:
898                 p.terminate()
899             return ret
900
901         # Something went wrong; e.g. we could not copy the results
902         # back, cleanup after ourselves on the remote machine, or
903         # unpickle the results we got from the remote machine.  If we
904         # still have an active ssh subprocess, keep waiting on it.
905         # Otherwise, time for an emergency reschedule.
906         except Exception as e:
907             logger.exception(e)
908             logger.error('%s: Something unexpected just happened...', bundle)
909             if p is not None:
910                 logger.warning(
911                     "%s: Failed to wrap up \"done\" bundle, re-waiting on active ssh.", bundle
912                 )
913                 return self.wait_for_process(p, bundle, depth + 1)
914             else:
915                 self.release_worker(bundle)
916                 return self.emergency_retry_nasty_bundle(bundle)
917
918     def process_work_result(self, bundle: BundleDetails) -> Any:
919         with self.status.lock:
920             is_original = bundle.src_bundle is None
921             was_cancelled = bundle.was_cancelled
922             username = bundle.username
923             machine = bundle.machine
924             result_file = bundle.result_file
925             code_file = bundle.code_file
926
927             # Whether original or backup, if we finished first we must
928             # fetch the results if the computation happened on a
929             # remote machine.
930             bundle.end_ts = time.time()
931             if not was_cancelled:
932                 assert bundle.machine is not None
933                 if bundle.hostname not in bundle.machine:
934                     cmd = f'{SCP} {username}@{machine}:{result_file} {result_file} 2>/dev/null'
935                     logger.info(
936                         "%s: Fetching results back from %s@%s via %s",
937                         bundle,
938                         username,
939                         machine,
940                         cmd,
941                     )
942
943                     # If either of these throw they are handled in
944                     # wait_for_process.
945                     attempts = 0
946                     while True:
947                         try:
948                             run_silently(cmd)
949                         except Exception as e:
950                             attempts += 1
951                             if attempts >= 3:
952                                 raise e
953                         else:
954                             break
955
956                     # Cleanup remote /tmp files.
957                     run_silently(
958                         f'{SSH} {username}@{machine}' f' "/bin/rm -f {code_file} {result_file}"'
959                     )
960                     logger.debug('Fetching results back took %.2fs', time.time() - bundle.end_ts)
961                 dur = bundle.end_ts - bundle.start_ts
962                 self.histogram.add_item(dur)
963
964         # Only the original worker should unpickle the file contents
965         # though since it's the only one whose result matters.  The
966         # original is also the only job that may delete result_file
967         # from disk.  Note that the original may have been cancelled
968         # if one of the backups finished first; it still must read the
969         # result from disk.  It still does that here with is_cancelled
970         # set.
971         if is_original:
972             logger.debug("%s: Unpickling %s.", bundle, result_file)
973             try:
974                 with open(result_file, 'rb') as rb:
975                     serialized = rb.read()
976                 result = cloudpickle.loads(serialized)
977             except Exception as e:
978                 logger.exception(e)
979                 logger.error('Failed to load %s... this is bad news.', result_file)
980                 self.release_worker(bundle)
981
982                 # Re-raise the exception; the code in wait_for_process may
983                 # decide to emergency_retry_nasty_bundle here.
984                 raise e
985             logger.debug('Removing local (master) %s and %s.', code_file, result_file)
986             os.remove(result_file)
987             os.remove(code_file)
988
989             # Notify any backups that the original is done so they
990             # should stop ASAP.  Do this whether or not we
991             # finished first since there could be more than one
992             # backup.
993             if bundle.backup_bundles is not None:
994                 for backup in bundle.backup_bundles:
995                     logger.debug(
996                         '%s: Notifying backup %s that it\'s cancelled', bundle, backup.uuid
997                     )
998                     backup.is_cancelled.set()
999
1000         # This is a backup job and, by now, we have already fetched
1001         # the bundle results.
1002         else:
1003             # Backup results don't matter, they just need to leave the
1004             # result file in the right place for their originals to
1005             # read/unpickle later.
1006             result = None
1007
1008             # Tell the original to stop if we finished first.
1009             if not was_cancelled:
1010                 orig_bundle = bundle.src_bundle
1011                 assert orig_bundle is not None
1012                 logger.debug(
1013                     '%s: Notifying original %s we beat them to it.', bundle, orig_bundle.uuid
1014                 )
1015                 orig_bundle.is_cancelled.set()
1016         self.release_worker(bundle, was_cancelled=was_cancelled)
1017         return result
1018
1019     def create_original_bundle(self, pickle, fname: str):
1020         uuid = string_utils.generate_uuid(omit_dashes=True)
1021         code_file = f'/tmp/{uuid}.code.bin'
1022         result_file = f'/tmp/{uuid}.result.bin'
1023
1024         logger.debug('Writing pickled code to %s', code_file)
1025         with open(code_file, 'wb') as wb:
1026             wb.write(pickle)
1027
1028         bundle = BundleDetails(
1029             pickled_code=pickle,
1030             uuid=uuid,
1031             fname=fname,
1032             worker=None,
1033             username=None,
1034             machine=None,
1035             hostname=platform.node(),
1036             code_file=code_file,
1037             result_file=result_file,
1038             pid=0,
1039             start_ts=time.time(),
1040             end_ts=0.0,
1041             slower_than_local_p95=False,
1042             slower_than_global_p95=False,
1043             src_bundle=None,
1044             is_cancelled=threading.Event(),
1045             was_cancelled=False,
1046             backup_bundles=[],
1047             failure_count=0,
1048         )
1049         self.status.record_bundle_details(bundle)
1050         logger.debug('%s: Created an original bundle', bundle)
1051         return bundle
1052
1053     def create_backup_bundle(self, src_bundle: BundleDetails):
1054         assert self.status.lock.locked()
1055         assert src_bundle.backup_bundles is not None
1056         n = len(src_bundle.backup_bundles)
1057         uuid = src_bundle.uuid + f'_backup#{n}'
1058
1059         backup_bundle = BundleDetails(
1060             pickled_code=src_bundle.pickled_code,
1061             uuid=uuid,
1062             fname=src_bundle.fname,
1063             worker=None,
1064             username=None,
1065             machine=None,
1066             hostname=src_bundle.hostname,
1067             code_file=src_bundle.code_file,
1068             result_file=src_bundle.result_file,
1069             pid=0,
1070             start_ts=time.time(),
1071             end_ts=0.0,
1072             slower_than_local_p95=False,
1073             slower_than_global_p95=False,
1074             src_bundle=src_bundle,
1075             is_cancelled=threading.Event(),
1076             was_cancelled=False,
1077             backup_bundles=None,  # backup backups not allowed
1078             failure_count=0,
1079         )
1080         src_bundle.backup_bundles.append(backup_bundle)
1081         self.status.record_bundle_details_already_locked(backup_bundle)
1082         logger.debug('%s: Created a backup bundle', backup_bundle)
1083         return backup_bundle
1084
1085     def schedule_backup_for_bundle(self, src_bundle: BundleDetails):
1086         assert self.status.lock.locked()
1087         assert src_bundle is not None
1088         backup_bundle = self.create_backup_bundle(src_bundle)
1089         logger.debug(
1090             '%s/%s: Scheduling backup for execution...', backup_bundle.uuid, backup_bundle.fname
1091         )
1092         self._helper_executor.submit(self.launch, backup_bundle)
1093
1094         # Results from backups don't matter; if they finish first
1095         # they will move the result_file to this machine and let
1096         # the original pick them up and unpickle them (and return
1097         # a result).
1098
1099     def emergency_retry_nasty_bundle(self, bundle: BundleDetails) -> Optional[fut.Future]:
1100         is_original = bundle.src_bundle is None
1101         bundle.worker = None
1102         avoid_last_machine = bundle.machine
1103         bundle.machine = None
1104         bundle.username = None
1105         bundle.failure_count += 1
1106         if is_original:
1107             retry_limit = 3
1108         else:
1109             retry_limit = 2
1110
1111         if bundle.failure_count > retry_limit:
1112             logger.error(
1113                 '%s: Tried this bundle too many times already (%dx); giving up.',
1114                 bundle,
1115                 retry_limit,
1116             )
1117             if is_original:
1118                 raise RemoteExecutorException(
1119                     f'{bundle}: This bundle can\'t be completed despite several backups and retries',
1120                 )
1121             else:
1122                 logger.error(
1123                     '%s: At least it\'s only a backup; better luck with the others.', bundle
1124                 )
1125             return None
1126         else:
1127             msg = f'>>> Emergency rescheduling {bundle} because of unexpected errors (wtf?!) <<<'
1128             logger.warning(msg)
1129             warnings.warn(msg)
1130             return self.launch(bundle, avoid_last_machine)
1131
1132     @overrides
1133     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
1134         if self.already_shutdown:
1135             raise Exception('Submitted work after shutdown.')
1136         pickle = make_cloud_pickle(function, *args, **kwargs)
1137         bundle = self.create_original_bundle(pickle, function.__name__)
1138         self.total_bundles_submitted += 1
1139         return self._helper_executor.submit(self.launch, bundle)
1140
1141     @overrides
1142     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
1143         if not self.already_shutdown:
1144             logger.debug('Shutting down RemoteExecutor %s', self.title)
1145             self.heartbeat_stop_event.set()
1146             self.heartbeat_thread.join()
1147             self._helper_executor.shutdown(wait)
1148             if not quiet:
1149                 print(self.histogram.__repr__(label_formatter='%ds'))
1150             self.already_shutdown = True
1151
1152
1153 @singleton
1154 class DefaultExecutors(object):
1155     """A container for a default thread, process and remote executor.
1156     These are not created until needed and we take care to clean up
1157     before process exit.
1158
1159     """
1160
1161     def __init__(self):
1162         self.thread_executor: Optional[ThreadExecutor] = None
1163         self.process_executor: Optional[ProcessExecutor] = None
1164         self.remote_executor: Optional[RemoteExecutor] = None
1165
1166     @staticmethod
1167     def ping(host) -> bool:
1168         logger.debug('RUN> ping -c 1 %s', host)
1169         try:
1170             x = cmd_with_timeout(f'ping -c 1 {host} >/dev/null 2>/dev/null', timeout_seconds=1.0)
1171             return x == 0
1172         except Exception:
1173             return False
1174
1175     def thread_pool(self) -> ThreadExecutor:
1176         if self.thread_executor is None:
1177             self.thread_executor = ThreadExecutor()
1178         return self.thread_executor
1179
1180     def process_pool(self) -> ProcessExecutor:
1181         if self.process_executor is None:
1182             self.process_executor = ProcessExecutor()
1183         return self.process_executor
1184
1185     def remote_pool(self) -> RemoteExecutor:
1186         if self.remote_executor is None:
1187             logger.info('Looking for some helper machines...')
1188             pool: List[RemoteWorkerRecord] = []
1189             if self.ping('cheetah.house'):
1190                 logger.info('Found cheetah.house')
1191                 pool.append(
1192                     RemoteWorkerRecord(
1193                         username='scott',
1194                         machine='cheetah.house',
1195                         weight=24,
1196                         count=5,
1197                     ),
1198                 )
1199             if self.ping('meerkat.cabin'):
1200                 logger.info('Found meerkat.cabin')
1201                 pool.append(
1202                     RemoteWorkerRecord(
1203                         username='scott',
1204                         machine='meerkat.cabin',
1205                         weight=12,
1206                         count=2,
1207                     ),
1208                 )
1209             if self.ping('wannabe.house'):
1210                 logger.info('Found wannabe.house')
1211                 pool.append(
1212                     RemoteWorkerRecord(
1213                         username='scott',
1214                         machine='wannabe.house',
1215                         weight=14,
1216                         count=2,
1217                     ),
1218                 )
1219             if self.ping('puma.cabin'):
1220                 logger.info('Found puma.cabin')
1221                 pool.append(
1222                     RemoteWorkerRecord(
1223                         username='scott',
1224                         machine='puma.cabin',
1225                         weight=24,
1226                         count=5,
1227                     ),
1228                 )
1229             if self.ping('backup.house'):
1230                 logger.info('Found backup.house')
1231                 pool.append(
1232                     RemoteWorkerRecord(
1233                         username='scott',
1234                         machine='backup.house',
1235                         weight=9,
1236                         count=2,
1237                     ),
1238                 )
1239
1240             # The controller machine has a lot to do; go easy on it.
1241             for record in pool:
1242                 if record.machine == platform.node() and record.count > 1:
1243                     logger.info('Reducing workload for %s.', record.machine)
1244                     record.count = max(int(record.count / 2), 1)
1245
1246             policy = WeightedRandomRemoteWorkerSelectionPolicy()
1247             policy.register_worker_pool(pool)
1248             self.remote_executor = RemoteExecutor(pool, policy)
1249         return self.remote_executor
1250
1251     def shutdown(self) -> None:
1252         if self.thread_executor is not None:
1253             self.thread_executor.shutdown(wait=True, quiet=True)
1254             self.thread_executor = None
1255         if self.process_executor is not None:
1256             self.process_executor.shutdown(wait=True, quiet=True)
1257             self.process_executor = None
1258         if self.remote_executor is not None:
1259             self.remote_executor.shutdown(wait=True, quiet=True)
1260             self.remote_executor = None