1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3
4 # © Copyright 2021-2022, Scott Gasch
5
6 """
7 This module defines a :class:`BaseExecutor` interface and three
8 implementations:
9
10     - :class:`ThreadExecutor`
11     - :class:`ProcessExecutor`
12     - :class:`RemoteExecutor`
13
14 The :class:`ThreadExecutor` is used to dispatch work to background
15 threads in the same Python process for parallelized work.  Of course,
16 until the Global Interpreter Lock (GIL) bottleneck is resolved, this
17 is not terribly useful for compute-bound code.  But it's good for
18 work that is mostly I/O bound.
19
20 The :class:`ProcessExecutor` is used to dispatch work to other
21 processes on the same machine and is more useful for compute-bound
22 workloads.
23
24 The :class:`RemoteExecutor` is used in conjunction with `ssh`,
25 the `cloudpickle` dependency, and the `remote_worker.py <https://wannabe.guru.org/gitweb/?p=pyutils.git;a=blob_plain;f=src/pyutils/remote_worker.py;hb=HEAD>`_ script
26 to dispatch work to a set of remote worker machines on your
27 network.  You can configure this pool via a JSON configuration file,
28 an example of which `can be found in examples <https://wannabe.guru.org/gitweb/?p=pyutils.git;a=blob_plain;f=examples/parallelize_config/.remote_worker_records;hb=HEAD>`_.
29
30 Finally, this file defines a :class:`DefaultExecutors` pool that
31 contains a pre-created and ready instance of each of the three
32 executors discussed.  It has the added benefit of being automatically
33 cleaned up at process termination time.
34
35 See instructions in :mod:`pyutils.parallelize.parallelize` for
36 setting up and using the framework.
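
A hedged, minimal usage sketch using the :class:`ThreadExecutor` defined
below.  In a real program you would first initialize the pyutils config
machinery (typically via :code:`config.parse()`, per the instructions
referenced above) so that this module's command line flags are honored::

    from pyutils.parallelize.executors import ThreadExecutor

    def square(n: int) -> int:
        return n * n

    executor = ThreadExecutor()
    futures = [executor.submit(square, i) for i in range(10)]
    print([f.result() for f in futures])    # [0, 1, 4, ..., 81]
    executor.shutdown(quiet=True)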
37 """
38
39 from __future__ import annotations
40
41 import concurrent.futures as fut
42 import logging
43 import os
44 import platform
45 import random
46 import subprocess
47 import threading
48 import time
49 import warnings
50 from abc import ABC, abstractmethod
51 from collections import defaultdict
52 from dataclasses import dataclass, fields
53 from typing import Any, Callable, Dict, List, Optional, Set
54
55 import cloudpickle  # type: ignore
56 from overrides import overrides
57
58 import pyutils.typez.histogram as hist
59 from pyutils import argparse_utils, config, math_utils, persistent, string_utils
60 from pyutils.ansi import bg, fg, reset, underline
61 from pyutils.decorator_utils import singleton
62 from pyutils.exec_utils import cmd_exitcode, cmd_in_background, run_silently
63 from pyutils.parallelize.thread_utils import background_thread
64
65 logger = logging.getLogger(__name__)
66
67 parser = config.add_commandline_args(
68     f"Executors ({__file__})", "Args related to processing executors."
69 )
70 parser.add_argument(
71     '--executors_threadpool_size',
72     type=int,
73     metavar='#THREADS',
74     help='Number of threads in the default threadpool, leave unset for default',
75     default=None,
76 )
77 parser.add_argument(
78     '--executors_processpool_size',
79     type=int,
80     metavar='#PROCESSES',
81     help='Number of processes in the default processpool, leave unset for default',
82     default=None,
83 )
84 parser.add_argument(
85     '--executors_schedule_remote_backups',
86     default=True,
87     action=argparse_utils.ActionNoYes,
88     help='Should we schedule duplicative backup work if a remote bundle is slow',
89 )
90 parser.add_argument(
91     '--executors_max_bundle_failures',
92     type=int,
93     default=3,
94     metavar='#FAILURES',
95     help='Maximum number of failures before giving up on a bundle',
96 )
97 parser.add_argument(
98     '--remote_worker_records_file',
99     type=str,
100     metavar='FILENAME',
101     help='Path of the remote worker records file (JSON)',
102     default=f'{os.environ.get("HOME", ".")}/.remote_worker_records',
103 )
104 parser.add_argument(
105     '--remote_worker_helper_path',
106     type=str,
107     metavar='PATH_TO_REMOTE_WORKER_PY',
108     help='Path to remote_worker.py on remote machines',
109     default='source py39-venv/bin/activate && /home/scott/lib/release/pyutils/src/pyutils/remote_worker.py',
110 )
111
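# A hedged illustration of how a program built on this module might be
# invoked from the command line using the flags registered above (the
# --no_... spelling assumes the usual argparse_utils.ActionNoYes behavior
# for boolean flags; the program name is a placeholder):
#
#   ./program.py --executors_threadpool_size 16 \
#       --executors_processpool_size 4 \
#       --no_executors_schedule_remote_backups
#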
112
113 SSH = '/usr/bin/ssh -oForwardX11=no'
114 SCP = '/usr/bin/scp -C'
115
116
117 def _make_cloud_pickle(fun, *args, **kwargs):
118     """Internal helper to create cloud pickles."""
119     logger.debug("Making cloudpickled bundle at %s", fun.__name__)
120     return cloudpickle.dumps((fun, args, kwargs))
121
122
123 class BaseExecutor(ABC):
124     """The base executor interface definition, implemented by
125     :class:`ProcessExecutor`, :class:`RemoteExecutor`, and
126     :class:`ThreadExecutor`.
127     """
128
129     def __init__(self, *, title=''):
130         """
131         Args:
132             title: the name of this executor.
133         """
134         self.title = title
135         self.histogram = hist.SimpleHistogram(
136             hist.SimpleHistogram.n_evenly_spaced_buckets(int(0), int(500), 50)
137         )
138         self.task_count = 0
139
140     @abstractmethod
141     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
142         """Submit work for the executor to do.
143
144         Args:
145             function: the Callable to be executed.
146             *args: the positional arguments to function
147             **kwargs: the keyword arguments to function
148
149         Returns:
150             A concurrent :class:`Future` representing the result of the
151             work.
152         """
153         pass
154
155     @abstractmethod
156     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
157         """Shut down the executor.
158
159         Args:
160             wait: wait for the shutdown to complete before returning?
161             quiet: keep it quiet, please.
162         """
163         pass
164
165     def shutdown_if_idle(self, *, quiet: bool = False) -> bool:
166         """Shut down the executor and return True if the executor is idle
167         (i.e. there are no pending or active tasks).  Return False
168         otherwise.  Note: this should only be called by the launcher
169         process.
170
171         Args:
172             quiet: keep it quiet, please.
173
174         Returns:
175             True if the executor could be shut down because it has no
176             pending work, False otherwise.
177         """
178         if self.task_count == 0:
179             self.shutdown(wait=True, quiet=quiet)
180             return True
181         return False
182
183     def adjust_task_count(self, delta: int) -> None:
184         """Change the task count.  Note: do not call this method from a
185         worker, it should only be called by the launcher process /
186         thread / machine.
187
188         Args:
189             delta: the delta value by which to adjust task count.
190         """
191         self.task_count += delta
192         logger.debug('Adjusted task count by %d to %d.', delta, self.task_count)
193
194     def get_task_count(self) -> int:
195         """Return the current task count.  Note: do not call this method from a
196         worker, it should only be called by the launcher process /
197         thread / machine.
198
199         Returns:
200             The executor's current task count.
201         """
202         return self.task_count
203
204
205 class ThreadExecutor(BaseExecutor):
206     """A threadpool executor.  This executor uses Python threads to
207     schedule tasks.  Note that, at least as of python3.10, because of
208     the global lock in the interpreter itself, these do not
209     parallelize very well so this class is useful mostly for non-CPU
210     intensive tasks.
211
212     See also :class:`ProcessExecutor` and :class:`RemoteExecutor`.
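
    A hedged sketch of an I/O-bound use (the URL is an arbitrary example)::

        import urllib.request

        def fetch_length(url: str) -> int:
            with urllib.request.urlopen(url) as f:
                return len(f.read())

        executor = ThreadExecutor(max_workers=4)
        future = executor.submit(fetch_length, 'https://www.example.com')
        print(future.result())
        executor.shutdown(quiet=True)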
213     """
214
215     def __init__(self, max_workers: Optional[int] = None):
216         """
217         Args:
218             max_workers: maximum number of threads to create in the pool.
219         """
220         super().__init__()
221         workers = None
222         if max_workers is not None:
223             workers = max_workers
224         elif 'executors_threadpool_size' in config.config:
225             workers = config.config['executors_threadpool_size']
226         if workers is not None:
227             logger.debug('Creating threadpool executor with %d workers', workers)
228         else:
229             logger.debug('Creating a default sized threadpool executor')
230         self._thread_pool_executor = fut.ThreadPoolExecutor(
231             max_workers=workers, thread_name_prefix="thread_executor_helper"
232         )
233         self.already_shutdown = False
234
235     # This is run on a different thread; do not adjust task count here.
236     @staticmethod
237     def _run_local_bundle(fun, *args, **kwargs):
238         logger.debug("Running local bundle at %s", fun.__name__)
239         result = fun(*args, **kwargs)
240         return result
241
242     @overrides
243     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
244         if self.already_shutdown:
245             raise Exception('Submitted work after shutdown.')
246         self.adjust_task_count(+1)
247         newargs = []
248         newargs.append(function)
249         for arg in args:
250             newargs.append(arg)
251         start = time.time()
252         result = self._thread_pool_executor.submit(
253             ThreadExecutor._run_local_bundle, *newargs, **kwargs
254         )
255         result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
256         result.add_done_callback(lambda _: self.adjust_task_count(-1))
257         return result
258
259     @overrides
260     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
261         if not self.already_shutdown:
262             logger.debug('Shutting down threadpool executor %s', self.title)
263             self._thread_pool_executor.shutdown(wait)
264             if not quiet:
265                 print(self.histogram.__repr__(label_formatter='%ds'))
266             self.already_shutdown = True
267
268
269 class ProcessExecutor(BaseExecutor):
270     """An executor which runs tasks in child processes.
271
272     See also :class:`ThreadExecutor` and :class:`RemoteExecutor`.
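
    A hedged sketch of a CPU-bound use (the work function and sizes are
    invented for illustration)::

        def sum_of_squares(n: int) -> int:
            return sum(i * i for i in range(n))

        executor = ProcessExecutor(max_workers=2)
        futures = [executor.submit(sum_of_squares, 5_000_000) for _ in range(4)]
        print(sum(f.result() for f in futures))
        executor.shutdown(quiet=True)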
273     """
274
275     def __init__(self, max_workers=None):
276         """
277         Args:
278             max_workers: the max number of worker processes to create.
279         """
280         super().__init__()
281         workers = None
282         if max_workers is not None:
283             workers = max_workers
284         elif 'executors_processpool_size' in config.config:
285             workers = config.config['executors_processpool_size']
286         if workers is not None:
287             logger.debug('Creating processpool executor with %d workers.', workers)
288         else:
289             logger.debug('Creating a default sized processpool executor')
290         self._process_executor = fut.ProcessPoolExecutor(
291             max_workers=workers,
292         )
293         self.already_shutdown = False
294
295     # This is run in another process; do not adjust task count here.
296     @staticmethod
297     def _run_cloud_pickle(pickle):
298         fun, args, kwargs = cloudpickle.loads(pickle)
299         logger.debug("Running pickled bundle at %s", fun.__name__)
300         result = fun(*args, **kwargs)
301         return result
302
303     @overrides
304     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
305         if self.already_shutdown:
306             raise Exception('Submitted work after shutdown.')
307         start = time.time()
308         self.adjust_task_count(+1)
309         pickle = _make_cloud_pickle(function, *args, **kwargs)
310         result = self._process_executor.submit(
311             ProcessExecutor._run_cloud_pickle, pickle
312         )
313         result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
314         result.add_done_callback(lambda _: self.adjust_task_count(-1))
315         return result
316
317     @overrides
318     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
319         if not self.already_shutdown:
320             logger.debug('Shutting down processpool executor %s', self.title)
321             self._process_executor.shutdown(wait)
322             if not quiet:
323                 print(self.histogram.__repr__(label_formatter='%ds'))
324             self.already_shutdown = True
325
326     def __getstate__(self):
327         state = self.__dict__.copy()
328         state['_process_executor'] = None
329         return state
330
331
332 class RemoteExecutorException(Exception):
333     """Thrown when a bundle cannot be executed despite several retries."""
334
335     pass
336
337
338 @dataclass
339 class RemoteWorkerRecord:
340     """A record of info about a remote worker."""
341
342     username: str
343     """Username we can ssh into on this machine to run work."""
344
345     machine: str
346     """Machine address / name."""
347
348     weight: int
349     """Relative probability for the weighted policy to select this
350     machine for scheduling work."""
351
352     count: int
353     """If this machine is selected, what is the maximum number of tasks
354     that it can handle?"""
355
356     def __hash__(self):
357         return hash((self.username, self.machine))
358
359     def __repr__(self):
360         return f'{self.username}@{self.machine}'
361
362
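# A hedged illustration of one worker entry in the JSON pool configuration
# read via --remote_worker_records_file.  The authoritative schema lives in
# the examples file linked from the module docstring; the field names below
# simply mirror RemoteWorkerRecord and the values are placeholders:
#
#   {
#       "username": "someuser",
#       "machine": "worker1.example.com",
#       "weight": 2,
#       "count": 4
#   }
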
363 @dataclass
364 class BundleDetails:
365     """All info necessary to define some unit of work that needs to be
366     done, where it is being run, its state, whether it is an original
367     bundle or a backup bundle, how many times it has failed, etc...
368     """
369
370     pickled_code: bytes
371     """The code to run, cloud pickled"""
372
373     uuid: str
374     """A unique identifier"""
375
376     function_name: str
377     """The name of the function we pickled"""
378
379     worker: Optional[RemoteWorkerRecord]
380     """The remote worker running this bundle or None if none (yet)"""
381
382     username: Optional[str]
383     """The remote username running this bundle or None if none (yet)"""
384
385     machine: Optional[str]
386     """The remote machine running this bundle or None if none (yet)"""
387
388     hostname: str
389     """The controller machine"""
390
391     code_file: str
392     """A unique filename to hold the work to be done"""
393
394     result_file: str
395     """Where the results should be placed / read from"""
396
397     pid: int
398     """The process id of the local subprocess watching the ssh connection
399     to the remote machine"""
400
401     start_ts: float
402     """Starting time"""
403
404     end_ts: float
405     """Ending time"""
406
407     slower_than_local_p95: bool
408     """Currently slower than 95% of other bundles on the remote host"""
409
410     slower_than_global_p95: bool
411     """Currently slower than 95% of other bundles globally"""
412
413     src_bundle: Optional[BundleDetails]
414     """If this is a backup bundle, this points to the original bundle
415     that it's backing up.  None otherwise."""
416
417     is_cancelled: threading.Event
418     """An event that can be signaled to indicate this bundle is cancelled.
419     This is set when another copy (backup or original) of this work has
420     completed successfully elsewhere."""
421
422     was_cancelled: bool
423     """True if this bundle was cancelled, False if it finished normally"""
424
425     backup_bundles: Optional[List[BundleDetails]]
426     """If we've created backups of this bundle, this is the list of them"""
427
428     failure_count: int
429     """How many times has this bundle failed already?"""
430
431     def __repr__(self):
432         uuid = self.uuid
433         if uuid[-9:-2] == '_backup':
434             uuid = uuid[:-9]
435             suffix = f'{uuid[-6:]}_b{self.uuid[-1:]}'
436         else:
437             suffix = uuid[-6:]
438
439         # We colorize the uuid based on some bits from it to make them
440         # stand out in the logging and help a reader correlate log messages
441         # related to the same bundle.
442         colorz = [
443             fg('violet red'),
444             fg('red'),
445             fg('orange'),
446             fg('peach orange'),
447             fg('yellow'),
448             fg('marigold yellow'),
449             fg('green yellow'),
450             fg('tea green'),
451             fg('cornflower blue'),
452             fg('turquoise blue'),
453             fg('tropical blue'),
454             fg('lavender purple'),
455             fg('medium purple'),
456         ]
457         c = colorz[int(uuid[-2:], 16) % len(colorz)]
458         function_name = (
459             self.function_name if self.function_name is not None else 'nofname'
460         )
461         machine = self.machine if self.machine is not None else 'nomachine'
462         return f'{c}{suffix}/{function_name}/{machine}{reset()}'
463
464
465 class RemoteExecutorStatus:
466     """A status 'scoreboard' for a remote executor tracking various
467     metrics and able to render a periodic dump of global state.
468     """
469
470     def __init__(self, total_worker_count: int) -> None:
471         """
472         Args:
473             total_worker_count: number of workers in the pool
474         """
475         self.worker_count: int = total_worker_count
476         self.known_workers: Set[RemoteWorkerRecord] = set()
477         self.start_time: float = time.time()
478         self.start_per_bundle: Dict[str, Optional[float]] = defaultdict(float)
479         self.end_per_bundle: Dict[str, float] = defaultdict(float)
480         self.finished_bundle_timings_per_worker: Dict[
481             RemoteWorkerRecord, math_utils.NumericPopulation
482         ] = {}
483         self.in_flight_bundles_by_worker: Dict[RemoteWorkerRecord, Set[str]] = {}
484         self.bundle_details_by_uuid: Dict[str, BundleDetails] = {}
485         self.finished_bundle_timings: math_utils.NumericPopulation = (
486             math_utils.NumericPopulation()
487         )
488         self.last_periodic_dump: Optional[float] = None
489         self.total_bundles_submitted: int = 0
490
491         # Protects reads and modification using self.  Also used
492         # as a memory fence for modifications to bundle.
493         self.lock: threading.Lock = threading.Lock()
494
495     def record_acquire_worker(self, worker: RemoteWorkerRecord, uuid: str) -> None:
496         """Record that bundle with uuid is assigned to a particular worker.
497
498         Args:
499             worker: the record of the worker to which uuid is assigned
500             uuid: the uuid of a bundle that has been assigned to a worker
501         """
502         with self.lock:
503             self.record_acquire_worker_already_locked(worker, uuid)
504
505     def record_acquire_worker_already_locked(
506         self, worker: RemoteWorkerRecord, uuid: str
507     ) -> None:
508         """Same as above but an entry point that doesn't acquire the lock
509         for codepaths where it's already held."""
510         assert self.lock.locked()
511         self.known_workers.add(worker)
512         self.start_per_bundle[uuid] = None
513         x = self.in_flight_bundles_by_worker.get(worker, set())
514         x.add(uuid)
515         self.in_flight_bundles_by_worker[worker] = x
516
517     def record_bundle_details(self, details: BundleDetails) -> None:
518         """Register the details about a bundle of work."""
519         with self.lock:
520             self.record_bundle_details_already_locked(details)
521
522     def record_bundle_details_already_locked(self, details: BundleDetails) -> None:
523         """Same as above but for codepaths that already hold the lock."""
524         assert self.lock.locked()
525         self.bundle_details_by_uuid[details.uuid] = details
526
527     def record_release_worker(
528         self,
529         worker: RemoteWorkerRecord,
530         uuid: str,
531         was_cancelled: bool,
532     ) -> None:
533         """Record that a bundle has released a worker."""
534         with self.lock:
535             self.record_release_worker_already_locked(worker, uuid, was_cancelled)
536
537     def record_release_worker_already_locked(
538         self,
539         worker: RemoteWorkerRecord,
540         uuid: str,
541         was_cancelled: bool,
542     ) -> None:
543         """Same as above but for codepaths that already hold the lock."""
544         assert self.lock.locked()
545         ts = time.time()
546         self.end_per_bundle[uuid] = ts
547         self.in_flight_bundles_by_worker[worker].remove(uuid)
548         if not was_cancelled:
549             start = self.start_per_bundle[uuid]
550             assert start is not None
551             bundle_latency = ts - start
552             x = self.finished_bundle_timings_per_worker.get(
553                 worker, math_utils.NumericPopulation()
554             )
555             x.add_number(bundle_latency)
556             self.finished_bundle_timings_per_worker[worker] = x
557             self.finished_bundle_timings.add_number(bundle_latency)
558
559     def record_processing_began(self, uuid: str):
560         """Record when work on a bundle begins."""
561         with self.lock:
562             self.start_per_bundle[uuid] = time.time()
563
564     def total_in_flight(self) -> int:
565         """How many bundles are in flight currently?"""
566         assert self.lock.locked()
567         total_in_flight = 0
568         for worker in self.known_workers:
569             total_in_flight += len(self.in_flight_bundles_by_worker[worker])
570         return total_in_flight
571
572     def total_idle(self) -> int:
573         """How many idle workers are there currently?"""
574         assert self.lock.locked()
575         return self.worker_count - self.total_in_flight()
576
577     def __repr__(self):
578         assert self.lock.locked()
579         ts = time.time()
580         total_finished = len(self.finished_bundle_timings)
581         total_in_flight = self.total_in_flight()
582         ret = f'\n\n{underline()}Remote Executor Pool Status{reset()}: '
583         qall = None
584         if len(self.finished_bundle_timings) > 1:
585             qall_median = self.finished_bundle_timings.get_median()
586             qall_p95 = self.finished_bundle_timings.get_percentile(95)
                qall = (qall_median, qall_p95)
587             ret += (
588                 f'⏱=∀p50:{qall_median:.1f}s, ∀p95:{qall_p95:.1f}s, total={ts-self.start_time:.1f}s, '
589                 f'✅={total_finished}/{self.total_bundles_submitted}, '
590                 f'💻n={total_in_flight}/{self.worker_count}\n'
591             )
592         else:
593             ret += (
594                 f'⏱={ts-self.start_time:.1f}s, '
595                 f'✅={total_finished}/{self.total_bundles_submitted}, '
596                 f'💻n={total_in_flight}/{self.worker_count}\n'
597             )
598
599         for worker in self.known_workers:
600             ret += f'  {fg("lightning yellow")}{worker.machine}{reset()}: '
601             timings = self.finished_bundle_timings_per_worker.get(
602                 worker, math_utils.NumericPopulation()
603             )
604             count = len(timings)
605             qworker_median = None
606             qworker_p95 = None
607             if count > 1:
608                 qworker_median = timings.get_median()
609                 qworker_p95 = timings.get_percentile(95)
610                 ret += f' 💻p50: {qworker_median:.1f}s, 💻p95: {qworker_p95:.1f}s\n'
611             else:
612                 ret += '\n'
613             if count > 0:
614                 ret += f'    ...finished {count} total bundle(s) so far\n'
615             in_flight = len(self.in_flight_bundles_by_worker[worker])
616             if in_flight > 0:
617                 ret += f'    ...{in_flight} bundles currently in flight:\n'
618                 for bundle_uuid in self.in_flight_bundles_by_worker[worker]:
619                     details = self.bundle_details_by_uuid.get(bundle_uuid, None)
620                     pid = str(details.pid) if (details and details.pid != 0) else "TBD"
621                     if self.start_per_bundle[bundle_uuid] is not None:
622                         sec = ts - self.start_per_bundle[bundle_uuid]
623                         ret += f'       (pid={pid}): {details} for {sec:.1f}s so far '
624                     else:
625                         ret += f'       {details} setting up / copying data...'
626                         sec = 0.0
627
628                     if qworker_p95 is not None:
629                         if sec > qworker_p95:
630                             ret += f'{bg("red")}>💻p95{reset()} '
631                             if details is not None:
632                                 details.slower_than_local_p95 = True
633                         else:
634                             if details is not None:
635                                 details.slower_than_local_p95 = False
636
637                     if qall is not None:
638                         if sec > qall[1]:
639                             ret += f'{bg("red")}>∀p95{reset()} '
640                             if details is not None:
641                                 details.slower_than_global_p95 = True
642                         else:
643                             if details is not None:
                                    details.slower_than_global_p95 = False
644                     ret += '\n'
645         return ret
646
647     def periodic_dump(self, total_bundles_submitted: int) -> None:
648         assert self.lock.locked()
649         self.total_bundles_submitted = total_bundles_submitted
650         ts = time.time()
651         if self.last_periodic_dump is None or ts - self.last_periodic_dump > 5.0:
652             print(self)
653             self.last_periodic_dump = ts
654
655
656 class RemoteWorkerSelectionPolicy(ABC):
657     """Base class for policies that select remote workers."""
658
659     def __init__(self):
660         self.workers: Optional[List[RemoteWorkerRecord]] = None
661
662     def register_worker_pool(self, workers: List[RemoteWorkerRecord]):
663         self.workers = workers
664
665     @abstractmethod
666     def is_worker_available(self) -> bool:
667         pass
668
669     @abstractmethod
670     def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
671         pass
672
673
674 class WeightedRandomRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
675     """A remote worker selector that uses weighted RNG."""
676
677     @overrides
678     def is_worker_available(self) -> bool:
679         if self.workers:
680             for worker in self.workers:
681                 if worker.count > 0:
682                     return True
683         return False
684
685     @overrides
686     def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
687         grabbag = []
688         if self.workers:
689             for worker in self.workers:
690                 if worker.machine != machine_to_avoid:
691                     if worker.count > 0:
692                         for _ in range(worker.count * worker.weight):
693                             grabbag.append(worker)
694
695         if len(grabbag) == 0:
696             logger.debug(
697                 'There are no available workers that avoid %s', machine_to_avoid
698             )
699             if self.workers:
700                 for worker in self.workers:
701                     if worker.count > 0:
702                         for _ in range(worker.count * worker.weight):
703                             grabbag.append(worker)
704
705         if len(grabbag) == 0:
706             logger.warning('There are no available workers?!')
707             return None
708
709         worker = random.sample(grabbag, 1)[0]
710         assert worker.count > 0
711         worker.count -= 1
712         logger.debug('Selected worker %s', worker)
713         return worker
714
715
716 class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
717     """A remote worker selector that just round-robins."""
718
719     def __init__(self) -> None:
720         super().__init__()
721         self.index = 0
722
723     @overrides
724     def is_worker_available(self) -> bool:
725         if self.workers:
726             for worker in self.workers:
727                 if worker.count > 0:
728                     return True
729         return False
730
731     @overrides
732     def acquire_worker(
733         self, machine_to_avoid: Optional[str] = None
734     ) -> Optional[RemoteWorkerRecord]:
735         if self.workers:
736             x = self.index
737             while True:
738                 worker = self.workers[x]
739                 if worker.count > 0:
740                     worker.count -= 1
741                     x += 1
742                     if x >= len(self.workers):
743                         x = 0
744                     self.index = x
745                     logger.debug('Selected worker %s', worker)
746                     return worker
747                 x += 1
748                 if x >= len(self.workers):
749                     x = 0
750                 if x == self.index:
751                     logger.warning('Unexpectedly could not find a worker, retrying...')
752                     return None
753         return None
754
755
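# A hedged, self-contained demonstration of the selection policy interface
# defined above (records and counts are placeholders):
#
#   policy = RoundRobinRemoteWorkerSelectionPolicy()
#   policy.register_worker_pool(
#       [
#           RemoteWorkerRecord('someuser', 'worker1.example.com', 1, 2),
#           RemoteWorkerRecord('someuser', 'worker2.example.com', 1, 2),
#       ]
#   )
#   while policy.is_worker_available():
#       print(policy.acquire_worker())   # alternates between the two hosts
#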
756 class RemoteExecutor(BaseExecutor):
757     """An executor that uses processes on remote machines to do work.  This
758     works by creating "bundles" of work with pickled code in each to be
759     executed.  Each bundle is assigned a remote worker based on some policy
760     heuristics.  Once assigned to a remote worker, a local subprocess is
761     created.  It copies the pickled code to the remote machine via ssh/scp
762     and then, again via ssh, starts the work on the remote machine by
763     invoking :file:`remote_worker.py` (`--remote_worker_helper_path`).  When
764     the work is complete, the local subprocess copies the results back to
765     the local machine via ssh/scp.
766
767     So there is essentially one "controller" machine (which may also be
768     in the remote executor pool and therefore do task work in addition to
769     controlling) and N worker machines.  This code runs on the controller
770     whereas on the worker machines we invoke pickled user code via a
771     shim in :file:`remote_worker.py`.
772
773     Some redundancy and safety provisions are made; e.g. slower than
774     expected tasks have redundant backups created and, if a task fails
775     repeatedly, we consider it poisoned and give up on it.
776
777     .. warning::
778
779         The network overhead / latency of copying work from the
780         controller machine to the remote workers is relatively high.
781         This executor probably only makes sense to use with
782         computationally expensive tasks such as jobs that will execute
783         for ~30 seconds or longer.
784
785     Instructions for how to set this up are provided in
786     :mod:`pyutils.parallelize.parallelize`.
787
788     See also :class:`ProcessExecutor` and :class:`ThreadExecutor`.
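
    A hedged construction sketch (usernames, hostnames, weights and counts
    are placeholders; in practice the worker pool is typically loaded from
    the --remote_worker_records_file JSON rather than built by hand)::

        def remote_task(n: int) -> int:
            return n * n

        workers = [
            RemoteWorkerRecord('someuser', 'worker1.example.com', weight=2, count=4),
            RemoteWorkerRecord('someuser', 'worker2.example.com', weight=1, count=2),
        ]
        executor = RemoteExecutor(workers, WeightedRandomRemoteWorkerSelectionPolicy())
        future = executor.submit(remote_task, 42)
        print(future.result())
        executor.shutdown()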
789     """
790
791     def __init__(
792         self,
793         workers: List[RemoteWorkerRecord],
794         policy: RemoteWorkerSelectionPolicy,
795     ) -> None:
796         """
797         Args:
798             workers: A list of remote workers we can call on to do tasks.
799             policy: A policy for selecting remote workers for tasks.
800         """
801
802         super().__init__()
803         self.workers = workers
804         self.policy = policy
805         self.worker_count = 0
806         for worker in self.workers:
807             self.worker_count += worker.count
808         if self.worker_count <= 0:
809             msg = f"We need somewhere to schedule work; count was {self.worker_count}"
810             logger.critical(msg)
811             raise RemoteExecutorException(msg)
812         self.policy.register_worker_pool(self.workers)
813         self.cv = threading.Condition()
814         logger.debug(
815             'Creating %d local threads, one per remote worker.', self.worker_count
816         )
817         self._helper_executor = fut.ThreadPoolExecutor(
818             thread_name_prefix="remote_executor_helper",
819             max_workers=self.worker_count,
820         )
821         self.status = RemoteExecutorStatus(self.worker_count)
822         self.total_bundles_submitted = 0
823         self.backup_lock = threading.Lock()
824         self.last_backup = None
825         (
826             self.heartbeat_thread,
827             self.heartbeat_stop_event,
828         ) = self._run_periodic_heartbeat()
829         self.already_shutdown = False
830
831     @background_thread
832     def _run_periodic_heartbeat(self, stop_event: threading.Event) -> None:
833         """
834         We create a background thread to invoke :meth:`_heartbeat` regularly
835         while we are scheduling work.  It does some accounting such as
836         looking for slow bundles to tag for backup creation, checking for
837         unexpected failures, and printing a fancy message on stdout.
838         """
839         while not stop_event.is_set():
840             time.sleep(5.0)
841             logger.debug('Running periodic heartbeat code...')
842             self._heartbeat()
843         logger.debug('Periodic heartbeat thread shutting down.')
844
845     def _heartbeat(self) -> None:
846         # Note: this is invoked on a background thread, not an
847         # executor thread.  Be careful what you do with it b/c it
848         # needs to get back and dump status again periodically.
849         with self.status.lock:
850             self.status.periodic_dump(self.total_bundles_submitted)
851
852             # Look for bundles to reschedule via executor.submit
853             if config.config['executors_schedule_remote_backups']:
854                 self._maybe_schedule_backup_bundles()
855
856     def _maybe_schedule_backup_bundles(self):
857         """Maybe schedule backup bundles if we see a very slow bundle."""
858
859         assert self.status.lock.locked()
860         num_done = len(self.status.finished_bundle_timings)
861         num_idle_workers = self.worker_count - self.task_count
862         now = time.time()
863         if (
864             num_done >= 2
865             and num_idle_workers > 0
866             and (self.last_backup is None or (now - self.last_backup > 9.0))
867             and self.backup_lock.acquire(blocking=False)
868         ):
869             try:
870                 assert self.backup_lock.locked()
871
872                 bundle_to_backup = None
873                 best_score = None
874                 for (
875                     worker,
876                     bundle_uuids,
877                 ) in self.status.in_flight_bundles_by_worker.items():
878
879                     # Prefer to schedule backups of bundles running on
880                     # slower machines.
881                     base_score = 0
882                     for record in self.workers:
883                         if worker.machine == record.machine:
884                             base_score = float(record.weight)
885                             base_score = 1.0 / base_score
886                             base_score *= 200.0
887                             base_score = int(base_score)
888                             break
889
890                     for uuid in bundle_uuids:
891                         bundle = self.status.bundle_details_by_uuid.get(uuid, None)
892                         if (
893                             bundle is not None
894                             and bundle.src_bundle is None
895                             and bundle.backup_bundles is not None
896                         ):
897                             score = base_score
898
899                             # Schedule backups of bundles running
900                             # longer; especially those that are
901                             # unexpectedly slow.
902                             start_ts = self.status.start_per_bundle[uuid]
903                             if start_ts is not None:
904                                 runtime = now - start_ts
905                                 score += runtime
906                                 logger.debug(
907                                     'score[%s] => %.1f  # latency boost', bundle, score
908                                 )
909
910                                 if bundle.slower_than_local_p95:
911                                     score += runtime / 2
912                                     logger.debug(
913                                         'score[%s] => %.1f  # >worker p95',
914                                         bundle,
915                                         score,
916                                     )
917
918                                 if bundle.slower_than_global_p95:
919                                     score += runtime / 4
920                                     logger.debug(
921                                         'score[%s] => %.1f  # >global p95',
922                                         bundle,
923                                         score,
924                                     )
925
926                             # Prefer backups of bundles that don't
927                             # have backups already.
928                             backup_count = len(bundle.backup_bundles)
929                             if backup_count == 0:
930                                 score *= 2
931                             elif backup_count == 1:
932                                 score /= 2
933                             elif backup_count == 2:
934                                 score /= 8
935                             else:
936                                 score = 0
937                             logger.debug(
938                                 'score[%s] => %.1f  # %d dup backup factor',
939                                 bundle,
940                                 score,
                                    backup_count,
941                             )
942
943                             if score != 0 and (
944                                 best_score is None or score > best_score
945                             ):
946                                 bundle_to_backup = bundle
947                                 assert bundle is not None
948                                 assert bundle.backup_bundles is not None
949                                 assert bundle.src_bundle is None
950                                 best_score = score
951
952                 # Note: this is all still happening on the heartbeat
953                 # runner thread.  That's ok because
954                 # _schedule_backup_for_bundle uses the executor to
955                 # submit the bundle again which will cause it to be
956                 # picked up by a worker thread and allow this thread
957                 # to return to run future heartbeats.
958                 if bundle_to_backup is not None:
959                     self.last_backup = now
960                     logger.info(
961                         '=====> SCHEDULING BACKUP %s (score=%.1f) <=====',
962                         bundle_to_backup,
963                         best_score,
964                     )
965                     self._schedule_backup_for_bundle(bundle_to_backup)
966             finally:
967                 self.backup_lock.release()
968
969     def _is_worker_available(self) -> bool:
970         """Is there a worker available currently?"""
971         return self.policy.is_worker_available()
972
973     def _acquire_worker(
974         self, machine_to_avoid: Optional[str] = None
975     ) -> Optional[RemoteWorkerRecord]:
976         """Try to acquire a worker."""
977         return self.policy.acquire_worker(machine_to_avoid)
978
979     def _find_available_worker_or_block(
980         self, machine_to_avoid: Optional[str] = None
981     ) -> RemoteWorkerRecord:
982         """Find a worker or block until one becomes available."""
983         with self.cv:
984             while not self._is_worker_available():
985                 self.cv.wait()
986             worker = self._acquire_worker(machine_to_avoid)
987             if worker is not None:
988                 return worker
989         msg = "We should never reach this point in the code"
990         logger.critical(msg)
991         raise Exception(msg)
992
993     def _release_worker(self, bundle: BundleDetails, *, was_cancelled=True) -> None:
994         """Release a previously acquired worker."""
995         worker = bundle.worker
996         assert worker is not None
997         logger.debug('Released worker %s', worker)
998         self.status.record_release_worker(
999             worker,
1000             bundle.uuid,
1001             was_cancelled,
1002         )
1003         with self.cv:
1004             worker.count += 1
1005             self.cv.notify()
1006         self.adjust_task_count(-1)
1007
1008     def _check_if_cancelled(self, bundle: BundleDetails) -> bool:
1009         """See if a particular bundle is cancelled.  Do not block."""
1010         with self.status.lock:
1011             if bundle.is_cancelled.wait(timeout=0.0):
1012                 logger.debug('Bundle %s is cancelled, bail out.', bundle.uuid)
1013                 bundle.was_cancelled = True
1014                 return True
1015         return False
1016
1017     def _launch(self, bundle: BundleDetails, override_avoid_machine=None) -> Any:
1018         """Find a worker for bundle or block until one is available."""
1019
1020         self.adjust_task_count(+1)
1021         uuid = bundle.uuid
1022         hostname = bundle.hostname
1023         avoid_machine = override_avoid_machine
1024         is_original = bundle.src_bundle is None
1025
1026         # Try not to schedule a backup on the same host as the original.
1027         if avoid_machine is None and bundle.src_bundle is not None:
1028             avoid_machine = bundle.src_bundle.machine
1029         worker = None
1030         while worker is None:
1031             worker = self._find_available_worker_or_block(avoid_machine)
1032         assert worker is not None
1033
1034         # Ok, found a worker.
1035         bundle.worker = worker
1036         machine = bundle.machine = worker.machine
1037         username = bundle.username = worker.username
1038         self.status.record_acquire_worker(worker, uuid)
1039         logger.debug('%s: Running bundle on %s...', bundle, worker)
1040
1041         # Before we do any work, make sure the bundle is still viable.
1042         # It may have been some time between when it was submitted and
1043         # now due to lack of worker availability and someone else may
1044         # have already finished it.
1045         if self._check_if_cancelled(bundle):
1046             try:
1047                 return self._process_work_result(bundle)
1048             except Exception as e:
1049                 logger.warning(
1050                     '%s: bundle says it\'s cancelled upfront but no results?!', bundle
1051                 )
1052                 self._release_worker(bundle)
1053                 if is_original:
1054                     # Weird.  We are the original owner of this
1055                     # bundle.  For it to have been cancelled, a backup
1056                     # must have already started and completed before
1057                     # we even got started.  Moreover, the backup says
1058                     # it is done but we can't find the results it
1059                     # should have copied over.  Reschedule the whole
1060                     # thing.
1061                     logger.exception(e)
1062                     logger.error(
1063                         '%s: We are the original owner thread and yet there are '
1064                         'no results for this bundle.  This is unexpected and bad.',
1065                         bundle,
1066                     )
1067                     return self._emergency_retry_nasty_bundle(bundle)
1068                 else:
1069                     # We're a backup and our bundle is cancelled
1070                     # before we even got started.  Do nothing and let
1071                     # the original bundle's thread worry about either
1072                     # finding the results or complaining about it.
1073                     return None
1074
1075         # Send input code / data to worker machine if it's not local.
1076         if hostname not in machine:
1077             try:
1078                 cmd = (
1079                     f'{SCP} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
1080                 )
1081                 start_ts = time.time()
1082                 logger.info("%s: Copying work to %s via %s.", bundle, worker, cmd)
1083                 run_silently(cmd)
1084                 xfer_latency = time.time() - start_ts
1085                 logger.debug(
1086                     "%s: Copying to %s took %.1fs.", bundle, worker, xfer_latency
1087                 )
1088             except Exception as e:
1089                 self._release_worker(bundle)
1090                 if is_original:
1091                     # Weird.  We tried to copy the code to the worker
1092                     # and it failed...  And we're the original bundle.
1093                     # We have to retry.
1094                     logger.exception(e)
1095                     logger.error(
1096                         "%s: Failed to send instructions to the worker machine?! "
1097                         "This is not expected; we\'re the original bundle so this shouldn\'t "
1098                         "be a race condition.  Attempting an emergency retry...",
1099                         bundle,
1100                     )
1101                     return self._emergency_retry_nasty_bundle(bundle)
1102                 else:
1103                     # This is actually expected; we're a backup.
1104                     # There's a race condition where someone else
1105                     # already finished the work and removed the source
1106                     # code_file before we could copy it.  Ignore.
1107                     logger.warning(
1108                         '%s: Failed to send instructions to the worker machine... '
1109                         'We\'re a backup and this may be caused by the original (or '
1110                         'some other backup) already finishing this work.  Ignoring.',
1111                         bundle,
1112                     )
1113                     return None
1114
1115         # Kick off the work.  Note that if this fails we let
1116         # _wait_for_process deal with it.
1117         self.status.record_processing_began(uuid)
1118         helper_path = config.config['remote_worker_helper_path']
1119         cmd = (
1120             f'{SSH} {bundle.username}@{bundle.machine} '
1121             f'"{helper_path} --code_file {bundle.code_file} --result_file {bundle.result_file}"'
1122         )
1123         logger.debug(
1124             '%s: Executing %s in the background to kick off work...', bundle, cmd
1125         )
1126         p = cmd_in_background(cmd, silent=True)
1127         bundle.pid = p.pid
1128         logger.debug(
1129             '%s: Local ssh process pid=%d; remote worker is %s.', bundle, p.pid, machine
1130         )
1131         return self._wait_for_process(p, bundle, 0)
1132
1133     def _wait_for_process(
1134         self, p: Optional[subprocess.Popen], bundle: BundleDetails, depth: int
1135     ) -> Any:
1136         """At this point we've copied the bundle's pickled code to the remote
1137         worker and started an ssh process that should be invoking the
1138         remote worker to have it execute the user's code.  See how
1139         that's going and wait for it to complete or fail.  Note that
1140         this code is recursive: there are codepaths where we decide to
1141         stop waiting for an ssh process (because another backup seems
1142         to have finished) but then fail to fetch or parse the results
1143         from that backup and thus call ourselves to continue waiting
1144         on an active ssh process.  This is the purpose of the depth
1145         argument: to curtail potential infinite recursion by giving up
1146         eventually.
1147
1148         Args:
1149             p: the Popen record of the ssh job
1150             bundle: the bundle of work being executed remotely
1151             depth: how many retries we've made so far.  Starts at zero.
1152
1153         """
1154
1155         machine = bundle.machine
1156         assert p is not None
1157         pid = p.pid  # pid of the ssh process
1158         if depth > 3:
1159             logger.error(
1160                 "I've gotten repeated errors waiting on this bundle; giving up on pid=%d",
1161                 pid,
1162             )
1163             p.terminate()
1164             self._release_worker(bundle)
1165             return self._emergency_retry_nasty_bundle(bundle)
1166
1167         # Spin until either the ssh job we scheduled finishes the
1168         # bundle or some backup worker signals that they finished it
1169         # before we could.
1170         while True:
1171             try:
1172                 p.wait(timeout=0.25)
1173             except subprocess.TimeoutExpired:
1174                 if self._check_if_cancelled(bundle):
1175                     logger.info(
1176                         '%s: looks like another worker finished bundle...', bundle
1177                     )
1178                     break
1179             else:
1180                 logger.info("%s: pid %d (%s) is finished!", bundle, pid, machine)
1181                 p = None
1182                 break
1183
1184         # If we get here we believe the bundle is done; either the ssh
1185         # subprocess finished (hopefully successfully) or we noticed
1186         # that some other worker seems to have completed the bundle
1187         # before us and we're bailing out.
1188         try:
1189             ret = self._process_work_result(bundle)
1190             if ret is not None and p is not None:
1191                 p.terminate()
1192             return ret
1193
1194         # Something went wrong; e.g. we could not copy the results
1195         # back, clean up after ourselves on the remote machine, or
1196         # unpickle the results we got from the remote machine.  If we
1197         # still have an active ssh subprocess, keep waiting on it.
1198         # Otherwise, time for an emergency reschedule.
1199         except Exception as e:
1200             logger.exception(e)
1201             logger.error('%s: Something unexpected just happened...', bundle)
1202             if p is not None:
1203                 logger.warning(
1204                     "%s: Failed to wrap up \"done\" bundle, re-waiting on active ssh.",
1205                     bundle,
1206                 )
1207                 return self._wait_for_process(p, bundle, depth + 1)
1208             else:
1209                 self._release_worker(bundle)
1210                 return self._emergency_retry_nasty_bundle(bundle)
1211
1212     def _process_work_result(self, bundle: BundleDetails) -> Any:
1213         """A bundle seems to be completed.  Check on the results."""
1214
1215         with self.status.lock:
1216             is_original = bundle.src_bundle is None
1217             was_cancelled = bundle.was_cancelled
1218             username = bundle.username
1219             machine = bundle.machine
1220             result_file = bundle.result_file
1221             code_file = bundle.code_file
1222
1223             # Whether original or backup, if we finished first we must
1224             # fetch the results if the computation happened on a
1225             # remote machine.
1226             bundle.end_ts = time.time()
1227             if not was_cancelled:
1228                 assert bundle.machine is not None
1229                 if bundle.hostname not in bundle.machine:
1230                     cmd = f'{SCP} {username}@{machine}:{result_file} {result_file} 2>/dev/null'
1231                     logger.info(
1232                         "%s: Fetching results back from %s@%s via %s",
1233                         bundle,
1234                         username,
1235                         machine,
1236                         cmd,
1237                     )
1238
1239                     # If either of these throw they are handled in
1240                     # _wait_for_process.
1241                     attempts = 0
1242                     while True:
1243                         try:
1244                             run_silently(cmd)
1245                         except Exception as e:
1246                             attempts += 1
1247                             if attempts >= 3:
1248                                 raise e
1249                         else:
1250                             break
1251
1252                     # Cleanup remote /tmp files.
1253                     run_silently(
1254                         f'{SSH} {username}@{machine}'
1255                         f' "/bin/rm -f {code_file} {result_file}"'
1256                     )
1257                     logger.debug(
1258                         'Fetching results back took %.2fs', time.time() - bundle.end_ts
1259                     )
1260                 dur = bundle.end_ts - bundle.start_ts
1261                 self.histogram.add_item(dur)
1262
1263         # Only the original worker should unpickle the file contents
1264         # though since it's the only one whose result matters.  The
1265         # original is also the only job that may delete result_file
1266         # from disk.  Note that the original may have been cancelled
1267         # if one of the backups finished first; it still must read the
1268         # result from disk.  It still does that here with is_cancelled
1269         # set.
1270         if is_original:
1271             logger.debug("%s: Unpickling %s.", bundle, result_file)
1272             try:
1273                 with open(result_file, 'rb') as rb:
1274                     serialized = rb.read()
1275                 result = cloudpickle.loads(serialized)
1276             except Exception as e:
1277                 logger.exception(e)
1278                 logger.error('Failed to load %s... this is bad news.', result_file)
1279                 self._release_worker(bundle)
1280
1281                 # Re-raise the exception; the code in _wait_for_process may
1282                 # decide to _emergency_retry_nasty_bundle here.
1283                 raise e
1284             logger.debug('Removing local (master) %s and %s.', code_file, result_file)
1285             os.remove(result_file)
1286             os.remove(code_file)
1287
1288             # Notify any backups that the original is done so they
1289             # should stop ASAP.  Do this whether or not we
1290             # finished first since there could be more than one
1291             # backup.
1292             if bundle.backup_bundles is not None:
1293                 for backup in bundle.backup_bundles:
1294                     logger.debug(
1295                         '%s: Notifying backup %s that it\'s cancelled',
1296                         bundle,
1297                         backup.uuid,
1298                     )
1299                     backup.is_cancelled.set()
1300
1301         # This is a backup job and, by now, we have already fetched
1302         # the bundle results.
1303         else:
1304             # Backup results don't matter, they just need to leave the
1305             # result file in the right place for their originals to
1306             # read/unpickle later.
1307             result = None
1308
1309             # Tell the original to stop if we finished first.
1310             if not was_cancelled:
1311                 orig_bundle = bundle.src_bundle
1312                 assert orig_bundle is not None
1313                 logger.debug(
1314                     '%s: Notifying original %s we beat them to it.',
1315                     bundle,
1316                     orig_bundle.uuid,
1317                 )
1318                 orig_bundle.is_cancelled.set()
1319         self._release_worker(bundle, was_cancelled=was_cancelled)
1320         return result
1321
1322     def _create_original_bundle(self, pickle, function_name: str):
1323         """Creates a bundle that is not a backup of any other bundle but
1324         rather represents a user task.
1325         """
1326
1327         uuid = string_utils.generate_uuid(omit_dashes=True)
1328         code_file = f'/tmp/{uuid}.code.bin'
1329         result_file = f'/tmp/{uuid}.result.bin'
1330
1331         logger.debug('Writing pickled code to %s', code_file)
1332         with open(code_file, 'wb') as wb:
1333             wb.write(pickle)
1334
1335         bundle = BundleDetails(
1336             pickled_code=pickle,
1337             uuid=uuid,
1338             function_name=function_name,
1339             worker=None,
1340             username=None,
1341             machine=None,
1342             hostname=platform.node(),
1343             code_file=code_file,
1344             result_file=result_file,
1345             pid=0,
1346             start_ts=time.time(),
1347             end_ts=0.0,
1348             slower_than_local_p95=False,
1349             slower_than_global_p95=False,
1350             src_bundle=None,
1351             is_cancelled=threading.Event(),
1352             was_cancelled=False,
1353             backup_bundles=[],
1354             failure_count=0,
1355         )
1356         self.status.record_bundle_details(bundle)
1357         logger.debug('%s: Created an original bundle', bundle)
1358         return bundle
1359
1360     def _create_backup_bundle(self, src_bundle: BundleDetails):
1361         """Creates a bundle that is a backup of another bundle that is
1362         running too slowly."""
1363
1364         assert self.status.lock.locked()
1365         assert src_bundle.backup_bundles is not None
1366         n = len(src_bundle.backup_bundles)
1367         uuid = src_bundle.uuid + f'_backup#{n}'
1368
1369         backup_bundle = BundleDetails(
1370             pickled_code=src_bundle.pickled_code,
1371             uuid=uuid,
1372             function_name=src_bundle.function_name,
1373             worker=None,
1374             username=None,
1375             machine=None,
1376             hostname=src_bundle.hostname,
1377             code_file=src_bundle.code_file,
1378             result_file=src_bundle.result_file,
1379             pid=0,
1380             start_ts=time.time(),
1381             end_ts=0.0,
1382             slower_than_local_p95=False,
1383             slower_than_global_p95=False,
1384             src_bundle=src_bundle,
1385             is_cancelled=threading.Event(),
1386             was_cancelled=False,
1387             backup_bundles=None,  # backup backups not allowed
1388             failure_count=0,
1389         )
1390         src_bundle.backup_bundles.append(backup_bundle)
1391         self.status.record_bundle_details_already_locked(backup_bundle)
1392         logger.debug('%s: Created a backup bundle', backup_bundle)
1393         return backup_bundle
1394
1395     def _schedule_backup_for_bundle(self, src_bundle: BundleDetails):
1396         """Schedule a backup of src_bundle."""
1397
1398         assert self.status.lock.locked()
1399         assert src_bundle is not None
1400         backup_bundle = self._create_backup_bundle(src_bundle)
1401         logger.debug(
1402             '%s/%s: Scheduling backup for execution...',
1403             backup_bundle.uuid,
1404             backup_bundle.function_name,
1405         )
1406         self._helper_executor.submit(self._launch, backup_bundle)
1407
1408         # Results from backups don't matter; if a backup finishes
1409         # first, it simply moves the result_file onto this machine so
1410         # that the original bundle can pick it up, unpickle it, and
1411         # return the result.
1412
1413     def _emergency_retry_nasty_bundle(
1414         self, bundle: BundleDetails
1415     ) -> Optional[fut.Future]:
1416         """Something unexpectedly failed with bundle.  Either retry it
1417         from the beginning or throw in the towel and give up on it."""
1418
1419         is_original = bundle.src_bundle is None
1420         bundle.worker = None
1421         avoid_last_machine = bundle.machine
1422         bundle.machine = None
1423         bundle.username = None
1424         bundle.failure_count += 1
1425         if is_original:
1426             retry_limit = 3
1427         else:
1428             retry_limit = 2
1429
1430         if bundle.failure_count > retry_limit:
1431             logger.error(
1432                 '%s: Tried this bundle too many times already (%dx); giving up.',
1433                 bundle,
1434                 retry_limit,
1435             )
1436             if is_original:
1437                 raise RemoteExecutorException(
1438                     f'{bundle}: This bundle can\'t be completed despite several backups and retries',
1439                 )
1440             else:
1441                 logger.error(
1442                     '%s: At least it\'s only a backup; better luck with the others.',
1443                     bundle,
1444                 )
1445             return None
1446         else:
1447             msg = f'>>> Emergency rescheduling {bundle} because of unexpected errors (wtf?!) <<<'
1448             logger.warning(msg)
1449             warnings.warn(msg)
1450             return self._launch(bundle, avoid_last_machine)
1451
1452     @overrides
1453     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
1454         """Submit work to be done.  This is the user entry point of this
1455         class.  A minimal usage sketch, with placeholder names, follows::
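
                 # my_function, arg1 and arg2 are placeholders for any
                 # cloudpickle-able callable and its arguments.
                 executor = DefaultExecutors().remote_pool()
                 future = executor.submit(my_function, arg1, arg2)
                 result = future.result()
         """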
1456         if self.already_shutdown:
1457             raise Exception('Submitted work after shutdown.')
1458         pickle = _make_cloud_pickle(function, *args, **kwargs)
1459         bundle = self._create_original_bundle(pickle, function.__name__)
1460         self.total_bundles_submitted += 1
1461         return self._helper_executor.submit(self._launch, bundle)
1462
1463     @overrides
1464     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
1465         """Shutdown the executor."""
1466         if not self.already_shutdown:
1467             logger.debug('Shutting down RemoteExecutor %s', self.title)
1468             self.heartbeat_stop_event.set()
1469             self.heartbeat_thread.join()
1470             self._helper_executor.shutdown(wait)
1471             if not quiet:
1472                 print(self.histogram.__repr__(label_formatter='%ds'))
1473             self.already_shutdown = True
1474
1475
1476 class RemoteWorkerPoolProvider:
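         """Interface for something that can return the pool of remote
         worker machines available to a :class:`RemoteExecutor`."""
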
1477     @abstractmethod
1478     def get_remote_workers(self) -> List[RemoteWorkerRecord]:
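             """Return the list of RemoteWorkerRecords in the pool."""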
1479         pass
1480
1481
1482 @persistent.persistent_autoloaded_singleton()  # type: ignore
1483 class ConfigRemoteWorkerPoolProvider(
1484     RemoteWorkerPoolProvider, persistent.JsonFileBasedPersistent
1485 ):
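         """Reads the remote worker pool out of the JSON file named by the
         ``remote_worker_records_file`` config option.  A rough sketch of
         the expected shape (the per-record keys shown here are
         illustrative; see :class:`RemoteWorkerRecord` for the
         authoritative field list)::

             {
                 "remote_worker_records": [
                     { "username": "you", "machine": "machine1", "count": 4 },
                     { "username": "you", "machine": "machine2", "count": 2 }
                 ]
             }
         """
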
1486     def __init__(self, json_remote_worker_pool: Dict[str, Any]):
1487         self.remote_worker_pool = []
1488         for record in json_remote_worker_pool['remote_worker_records']:
1489             self.remote_worker_pool.append(
1490                 self.dataclassFromDict(RemoteWorkerRecord, record)
1491             )
1492         assert len(self.remote_worker_pool) > 0
1493
1494     @staticmethod
1495     def dataclassFromDict(clsName, argDict: Dict[str, Any]) -> Any:
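             """Construct an instance of dataclass clsName from argDict,
             keeping only the keys that match the dataclass' init fields
             and silently dropping the rest."""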
1496         fieldSet = {f.name for f in fields(clsName) if f.init}
1497         filteredArgDict = {k: v for k, v in argDict.items() if k in fieldSet}
1498         return clsName(**filteredArgDict)
1499
1500     @overrides
1501     def get_remote_workers(self) -> List[RemoteWorkerRecord]:
1502         return self.remote_worker_pool
1503
1504     @overrides
1505     def get_persistent_data(self) -> List[RemoteWorkerRecord]:
1506         return self.remote_worker_pool
1507
1508     @staticmethod
1509     @overrides
1510     def get_filename() -> str:
1511         return config.config['remote_worker_records_file']
1512
1513     @staticmethod
1514     @overrides
1515     def should_we_load_data(filename: str) -> bool:
1516         return True
1517
1518     @staticmethod
1519     @overrides
1520     def should_we_save_data(filename: str) -> bool:
1521         return False
1522
1523
1524 @singleton
1525 class DefaultExecutors(object):
1526     """A container for a default thread, process and remote executor.
1527     They are not created until needed and are automatically cleaned
1528     up before process exit for the caller's convenience.
1529     Instead of creating your own executor, consider using the one
1530     from this pool.  e.g.::
1531
1532         @par.parallelize(method=par.Method.PROCESS)
1533         def do_work(
1534             solutions: List[Work],
1535             shard_num: int,
1536             ...
1537         ):
1538             <do the work>
1539
1540
1541         def start_do_work(all_work: List[Work]):
1542             shards = []
                 results = []
1543             logger.debug('Sharding work into groups of 10.')
1544             for subset in list_utils.shard(all_work, 10):
1545                 shards.append([x for x in subset])
1546
1547             logger.debug('Kicking off helper pool.')
1548             try:
1549                 for n, shard in enumerate(shards):
1550                     results.append(
1551                         do_work(
1552                             shard, n, shared_cache.get_name(), max_letter_pop_per_word
1553                         )
1554                     )
1555                 smart_future.wait_all(results)
1556             finally:
1557                 # Note: if you forget to do this, the pool will clean
1558                 # itself up during program termination, including tearing
1559                 # down any active ssh connections.
1560                 executors.DefaultExecutors().process_pool().shutdown()
1561     """
1562
1563     def __init__(self):
1564         self.thread_executor: Optional[ThreadExecutor] = None
1565         self.process_executor: Optional[ProcessExecutor] = None
1566         self.remote_executor: Optional[RemoteExecutor] = None
1567
1568     @staticmethod
1569     def _ping(host) -> bool:
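             """Return True if host answers a single ping within ~1s."""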
1570         logger.debug('RUN> ping -c 1 %s', host)
1571         try:
1572             x = cmd_exitcode(
1573                 f'ping -c 1 {host} >/dev/null 2>/dev/null', timeout_seconds=1.0
1574             )
1575             return x == 0
1576         except Exception:
1577             return False
1578
1579     def thread_pool(self) -> ThreadExecutor:
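             """Return the shared ThreadExecutor, creating it on first use."""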
1580         if self.thread_executor is None:
1581             self.thread_executor = ThreadExecutor()
1582         return self.thread_executor
1583
1584     def process_pool(self) -> ProcessExecutor:
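             """Return the shared ProcessExecutor, creating it on first use."""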
1585         if self.process_executor is None:
1586             self.process_executor = ProcessExecutor()
1587         return self.process_executor
1588
1589     def remote_pool(self) -> RemoteExecutor:
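             """Return the shared RemoteExecutor, creating it on first use:
             load the remote worker config, drop machines that don't answer
             a ping, lighten the load on the controller machine, and wire
             up a weighted-random worker selection policy."""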
1590         if self.remote_executor is None:
1591             logger.info('Looking for some helper machines...')
1592             provider = ConfigRemoteWorkerPoolProvider()
1593             all_machines = provider.get_remote_workers()
1594             pool = []
1595
1596             # Make sure we can ping each machine.
1597             for record in all_machines:
1598                 if self._ping(record.machine):
1599                     logger.info('%s is alive / responding to pings', record.machine)
1600                     pool.append(record)
1601
1602             # The controller machine has a lot to do; go easy on it.
1603             for record in pool:
1604                 if record.machine == platform.node() and record.count > 1:
1605                     logger.info('Reducing workload for %s.', record.machine)
1606                     record.count = max(int(record.count / 2), 1)
1607
1608             policy = WeightedRandomRemoteWorkerSelectionPolicy()
1609             policy.register_worker_pool(pool)
1610             self.remote_executor = RemoteExecutor(pool, policy)
1611         return self.remote_executor
1612
1613     def shutdown(self) -> None:
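             """Shut down (and discard) any default executors that were
             created, so they can be recreated later if needed."""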
1614         if self.thread_executor is not None:
1615             self.thread_executor.shutdown(wait=True, quiet=True)
1616             self.thread_executor = None
1617         if self.process_executor is not None:
1618             self.process_executor.shutdown(wait=True, quiet=True)
1619             self.process_executor = None
1620         if self.remote_executor is not None:
1621             self.remote_executor.shutdown(wait=True, quiet=True)
1622             self.remote_executor = None