1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 # pylint: disable=too-many-instance-attributes
4 # pylint: disable=too-many-nested-blocks
5
6 # © Copyright 2021-2023, Scott Gasch
7
8 """
9 This module defines a :class:`BaseExecutor` interface and three
10 implementations:
11
12     - :class:`ThreadExecutor`
13     - :class:`ProcessExecutor`
14     - :class:`RemoteExecutor`
15
16 The :class:`ThreadExecutor` is used to dispatch work to background
17 threads in the same Python process for parallelized work.  Of course,
18 until the Global Interpreter Lock (GIL) bottleneck is resolved, this
19 is not terribly useful for compute-bound code.  But it's good for
20 work that is mostly I/O bound.
21
22 The :class:`ProcessExecutor` is used to dispatch work to other
23 processes on the same machine and is more useful for compute-bound
24 workloads.
25
26 The :class:`RemoteExecutor` is used in conjunction with `ssh`,
27 the `cloudpickle` dependency, and the `remote_worker.py <https://wannabe.guru.org/gitweb/?p=pyutils.git;a=blob_plain;f=src/pyutils/remote_worker.py;hb=HEAD>`_ script
28 to dispatch work to a set of remote worker machines on your
29 network.  You can configure this pool via a JSON configuration file,
30 an example of which `can be found in examples <https://wannabe.guru.org/gitweb/?p=pyutils.git;a=blob_plain;f=examples/parallelize_config/.remote_worker_records;hb=HEAD>`_.
31
32 Finally, this file defines a :class:`DefaultExecutors` pool that
33 contains a pre-created and ready instance of each of the three
34 executors discussed.  It has the added benefit of being automatically
35 cleaned up at process termination time.
36
37 See instructions in :mod:`pyutils.parallelize.parallelize` for
38 setting up and using the framework.
39 """
40
41 from __future__ import annotations
42
43 import concurrent.futures as fut
44 import logging
45 import os
46 import platform
47 import random
48 import subprocess
49 import threading
50 import time
51 import warnings
52 from abc import ABC, abstractmethod
53 from collections import defaultdict
54 from dataclasses import dataclass
55 from typing import Any, Callable, Dict, List, Optional, Set
56
57 import cloudpickle  # type: ignore
58 from overrides import overrides
59
60 import pyutils.typez.histogram as hist
61 from pyutils import (
62     argparse_utils,
63     config,
64     dataclass_utils,
65     math_utils,
66     persistent,
67     string_utils,
68 )
69 from pyutils.ansi import bg, fg, reset, underline
70 from pyutils.decorator_utils import singleton
71 from pyutils.exec_utils import cmd_exitcode, cmd_in_background, run_silently
72 from pyutils.parallelize.thread_utils import background_thread
73 from pyutils.typez import type_utils
74
75 logger = logging.getLogger(__name__)
76
77 parser = config.add_commandline_args(
78     f"Executors ({__file__})", "Args related to processing executors."
79 )
80 parser.add_argument(
81     '--executors_threadpool_size',
82     type=int,
83     metavar='#THREADS',
84     help='Number of threads in the default threadpool, leave unset for default',
85     default=None,
86 )
87 parser.add_argument(
88     '--executors_processpool_size',
89     type=int,
90     metavar='#PROCESSES',
91     help='Number of processes in the default processpool, leave unset for default',
92     default=None,
93 )
94 parser.add_argument(
95     '--executors_schedule_remote_backups',
96     default=True,
97     action=argparse_utils.ActionNoYes,
98     help='Should we schedule duplicative backup work if a remote bundle is slow',
99 )
100 parser.add_argument(
101     '--executors_max_bundle_failures',
102     type=int,
103     default=3,
104     metavar='#FAILURES',
105     help='Maximum number of failures before giving up on a bundle',
106 )
107 parser.add_argument(
108     '--remote_worker_records_file',
109     type=str,
110     metavar='FILENAME',
111     help='Path of the remote worker records file (JSON)',
112     default=f'{os.environ.get("HOME", ".")}/.remote_worker_records',
113 )
114 parser.add_argument(
115     '--remote_worker_helper_path',
116     type=str,
117     metavar='PATH_TO_REMOTE_WORKER_PY',
118     help='Path to remote_worker.py on remote machines',
119     default=f'source py39-venv/bin/activate && {os.environ["HOME"]}/pyutils/src/pyutils/remote_worker.py',
120 )
121
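# These flags are registered with pyutils.config above; a hypothetical
# program built on this module (the program name below is a placeholder)
# might be invoked as:
#
#     python3 your_program.py \
#         --executors_threadpool_size 16 \
#         --executors_processpool_size 8 \
#         --remote_worker_records_file ~/.remote_worker_records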
122
123 SSH = '/usr/bin/ssh -oForwardX11=no'
124 SCP = '/usr/bin/scp -C'
125
126
127 def _make_cloud_pickle(fun, *args, **kwargs):
128     """Internal helper to create cloud pickles."""
129     logger.debug("Making cloudpickled bundle at %s", fun.__name__)
130     return cloudpickle.dumps((fun, args, kwargs))
131
132
133 class BaseExecutor(ABC):
134     """The base executor interface definition.  The interface for
135     :class:`ProcessExecutor`, :class:`RemoteExecutor`, and
136     :class:`ThreadExecutor`.
137     """
138
139     def __init__(self, *, title=''):
140         """
141         Args:
142             title: the name of this executor.
143         """
144         self.title = title
145         self.histogram = hist.SimpleHistogram(
146             hist.SimpleHistogram.n_evenly_spaced_buckets(int(0), int(500), 50)
147         )
148         self.task_count = 0
149
150     @abstractmethod
151     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
152         """Submit work for the executor to do.
153
154         Args:
155             function: the Callable to be executed.
156             *args: the arguments to function
157             **kwargs: the arguments to function
158
159         Returns:
160             A concurrent :class:`Future` representing the result of the
161             work.
162         """
163         pass
164
165     @abstractmethod
166     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
167         """Shutdown the executor.
168
169         Args:
170             wait: wait for the shutdown to complete before returning?
171             quiet: keep it quiet, please.
172         """
173         pass
174
175     def shutdown_if_idle(self, *, quiet: bool = False) -> bool:
176         """Shutdown the executor and return True if the executor is idle
177         (i.e. there are no pending or active tasks).  Return False
178         otherwise.  Note: this should only be called by the launcher
179         process.
180
181         Args:
182             quiet: keep it quiet, please.
183
184         Returns:
185             True if the executor could be shut down because it has no
186             pending work, False otherwise.
187         """
188         if self.task_count == 0:
189             self.shutdown(wait=True, quiet=quiet)
190             return True
191         return False
192
193     def adjust_task_count(self, delta: int) -> None:
194         """Change the task count.  Note: do not call this method from a
195         worker, it should only be called by the launcher process /
196         thread / machine.
197
198         Args:
199             delta: the delta value by which to adjust task count.
200         """
201         self.task_count += delta
202         logger.debug('Adjusted task count by %d to %d.', delta, self.task_count)
203
204     def get_task_count(self) -> int:
205         """Change the task count.  Note: do not call this method from a
206         worker, it should only be called by the launcher process /
207         thread / machine.
208
209         Returns:
210             The executor's current task count.
211         """
212         return self.task_count
213
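# Implementing the interface (an illustrative sketch only): a subclass need
# only provide submit() and shutdown(); task counting and the runtime
# histogram come from BaseExecutor itself.
#
#     class InlineExecutor(BaseExecutor):
#         """Runs submitted work synchronously in the caller's thread."""
#
#         @overrides
#         def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
#             self.adjust_task_count(+1)
#             future: fut.Future = fut.Future()
#             try:
#                 future.set_result(function(*args, **kwargs))
#             except Exception as e:
#                 future.set_exception(e)
#             finally:
#                 self.adjust_task_count(-1)
#             return future
#
#         @overrides
#         def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
#             pass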
214
215 class ThreadExecutor(BaseExecutor):
216     """A threadpool executor.  This executor uses Python threads to
217     schedule tasks.  Note that, at least as of python3.10, because of
218     the global lock in the interpreter itself, these do not
219     parallelize very well so this class is useful mostly for non-CPU
220     intensive tasks.
221
222     See also :class:`ProcessExecutor` and :class:`RemoteExecutor`.
223     """
224
225     def __init__(self, max_workers: Optional[int] = None):
226         """
227         Args:
228             max_workers: maximum number of threads to create in the pool.
229         """
230         super().__init__()
231         workers = None
232         if max_workers is not None:
233             workers = max_workers
234         elif 'executors_threadpool_size' in config.config:
235             workers = config.config['executors_threadpool_size']
236         if workers is not None:
237             logger.debug('Creating threadpool executor with %d workers', workers)
238         else:
239             logger.debug('Creating a default sized threadpool executor')
240         self._thread_pool_executor = fut.ThreadPoolExecutor(
241             max_workers=workers, thread_name_prefix="thread_executor_helper"
242         )
243         self.already_shutdown = False
244
245     # This is run on a different thread; do not adjust task count here.
246     @staticmethod
247     def _run_local_bundle(fun, *args, **kwargs):
248         logger.debug("Running local bundle at %s", fun.__name__)
249         result = fun(*args, **kwargs)
250         return result
251
252     @overrides
253     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
254         if self.already_shutdown:
255             raise Exception('Submitted work after shutdown.')
256         self.adjust_task_count(+1)
257         newargs = []
258         newargs.append(function)
259         for arg in args:
260             newargs.append(arg)
261         start = time.time()
262         result = self._thread_pool_executor.submit(
263             ThreadExecutor._run_local_bundle, *newargs, **kwargs
264         )
265         result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
266         result.add_done_callback(lambda _: self.adjust_task_count(-1))
267         return result
268
269     @overrides
270     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
271         if not self.already_shutdown:
272             logger.debug('Shutting down threadpool executor %s', self.title)
273             self._thread_pool_executor.shutdown(wait)
274             if not quiet:
275                 print(self.histogram.__repr__(label_formatter='%ds'))
276             self.already_shutdown = True
277
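# Fan-out sketch (illustrative only): submit several I/O-bound tasks and
# gather results as they complete using the concurrent.futures helpers.
# `fetch` and `urls` below are placeholders for your own callable and data.
#
#     executor = ThreadExecutor()
#     futures = [executor.submit(fetch, url) for url in urls]
#     for f in fut.as_completed(futures):
#         print(f.result())
#     executor.shutdown(wait=True, quiet=True)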
278
279 class ProcessExecutor(BaseExecutor):
280     """An executor which runs tasks in child processes.
281
282     See also :class:`ThreadExecutor` and :class:`RemoteExecutor`.
283     """
284
285     def __init__(self, max_workers=None):
286         """
287         Args:
288             max_workers: the max number of worker processes to create.
289         """
290         super().__init__()
291         workers = None
292         if max_workers is not None:
293             workers = max_workers
294         elif 'executors_processpool_size' in config.config:
295             workers = config.config['executors_processpool_size']
296         if workers is not None:
297             logger.debug('Creating processpool executor with %d workers.', workers)
298         else:
299             logger.debug('Creating a default sized processpool executor')
300         self._process_executor = fut.ProcessPoolExecutor(
301             max_workers=workers,
302         )
303         self.already_shutdown = False
304
305     # This is run in another process; do not adjust task count here.
306     @staticmethod
307     def _run_cloud_pickle(pickle):
308         fun, args, kwargs = cloudpickle.loads(pickle)
309         logger.debug("Running pickled bundle at %s", fun.__name__)
310         result = fun(*args, **kwargs)
311         return result
312
313     @overrides
314     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
315         if self.already_shutdown:
316             raise Exception('Submitted work after shutdown.')
317         start = time.time()
318         self.adjust_task_count(+1)
319         pickle = _make_cloud_pickle(function, *args, **kwargs)
320         result = self._process_executor.submit(
321             ProcessExecutor._run_cloud_pickle, pickle
322         )
323         result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
324         result.add_done_callback(lambda _: self.adjust_task_count(-1))
325         return result
326
327     @overrides
328     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
329         if not self.already_shutdown:
330             logger.debug('Shutting down processpool executor %s', self.title)
331             self._process_executor.shutdown(wait)
332             if not quiet:
333                 print(self.histogram.__repr__(label_formatter='%ds'))
334             self.already_shutdown = True
335
336     def __getstate__(self):
337         state = self.__dict__.copy()
338         state['_process_executor'] = None
339         return state
340
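# CPU-bound sketch (illustrative only).  Because work is shipped to child
# processes as a cloudpickle, locally defined functions and lambdas can be
# submitted.  On platforms that spawn workers (e.g. Windows, macOS), keep
# submissions under an import guard:
#
#     def busy(n: int) -> int:
#         return sum(i * i for i in range(n))
#
#     if __name__ == '__main__':
#         executor = ProcessExecutor(max_workers=os.cpu_count())
#         futures = [executor.submit(busy, 10_000_000) for _ in range(8)]
#         print(sum(f.result() for f in futures))
#         executor.shutdown(wait=True, quiet=True)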
341
342 class RemoteExecutorException(Exception):
343     """Thrown when a bundle cannot be executed despite several retries."""
344
345     pass
346
347
348 @dataclass
349 class RemoteWorkerRecord:
350     """A record of info about a remote worker."""
351
352     username: str
353     """Username we can ssh into on this machine to run work."""
354
355     machine: str
356     """Machine address / name."""
357
358     weight: int
359     """Relative probability for the weighted policy to select this
360     machine for scheduling work."""
361
362     count: int
363     """If this machine is selected, what is the maximum number of task
364     that it can handle?"""
365
366     def __hash__(self):
367         return hash((self.username, self.machine))
368
369     def __repr__(self):
370         return f'{self.username}@{self.machine}'
371
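# Construction sketch (illustrative only): two workers, one with twice the
# weight (relative speed) of the other.  The JSON file named by
# --remote_worker_records_file describes the same per-worker fields; see the
# example linked in the module docstring for the exact on-disk schema.
#
#     workers = [
#         RemoteWorkerRecord(username='me', machine='machine1', weight=2, count=8),
#         RemoteWorkerRecord(username='me', machine='machine2', weight=1, count=4),
#     ]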
372
373 @dataclass
374 class BundleDetails:
375     """All info necessary to define some unit of work that needs to be
376     done, where it is being run, its state, whether it is an original
377     bundle or a backup bundle, how many times it has failed, etc.
378     """
379
380     pickled_code: bytes
381     """The code to run, cloud pickled"""
382
383     uuid: str
384     """A unique identifier"""
385
386     function_name: str
387     """The name of the function we pickled"""
388
389     worker: Optional[RemoteWorkerRecord]
390     """The remote worker running this bundle or None if none (yet)"""
391
392     username: Optional[str]
393     """The remote username running this bundle or None if none (yet)"""
394
395     machine: Optional[str]
396     """The remote machine running this bundle or None if none (yet)"""
397
398     controller: str
399     """The controller machine"""
400
401     code_file: str
402     """A unique filename to hold the work to be done"""
403
404     result_file: str
405     """Where the results should be placed / read from"""
406
407     pid: int
408     """The process id of the local subprocess watching the ssh connection
409     to the remote machine"""
410
411     start_ts: float
412     """Starting time"""
413
414     end_ts: float
415     """Ending time"""
416
417     slower_than_local_p95: bool
418     """Currently slower then 95% of other bundles on remote host"""
419
420     slower_than_global_p95: bool
421     """Currently slower than 95% of other bundles globally"""
422
423     src_bundle: Optional[BundleDetails]
424     """If this is a backup bundle, this points to the original bundle
425     that it's backing up.  None otherwise."""
426
427     is_cancelled: threading.Event
428     """An event that can be signaled to indicate this bundle is cancelled.
429     This is set when another copy (backup or original) of this work has
430     completed successfully elsewhere."""
431
432     was_cancelled: bool
433     """True if this bundle was cancelled, False if it finished normally"""
434
435     backup_bundles: Optional[List[BundleDetails]]
436     """If we've created backups of this bundle, this is the list of them"""
437
438     failure_count: int
439     """How many times has this bundle failed already?"""
440
441     def __repr__(self):
442         uuid = self.uuid
443         if uuid[-9:-2] == '_backup':
444             uuid = uuid[:-9]
445             suffix = f'{uuid[-6:]}_b{self.uuid[-1:]}'
446         else:
447             suffix = uuid[-6:]
448
449         # We colorize the uuid based on some bits from it to make them
450         # stand out in the logging and help a reader correlate log messages
451         # related to the same bundle.
452         colorz = [
453             fg('violet red'),
454             fg('red'),
455             fg('orange'),
456             fg('peach orange'),
457             fg('yellow'),
458             fg('marigold yellow'),
459             fg('green yellow'),
460             fg('tea green'),
461             fg('cornflower blue'),
462             fg('turquoise blue'),
463             fg('tropical blue'),
464             fg('lavender purple'),
465             fg('medium purple'),
466         ]
467         c = colorz[int(uuid[-2:], 16) % len(colorz)]
468         function_name = (
469             self.function_name if self.function_name is not None else 'nofname'
470         )
471         machine = self.machine if self.machine is not None else 'nomachine'
472         return f'{c}{suffix}/{function_name}/{machine}{reset()}'
473
474
475 class RemoteExecutorStatus:
476     """A status 'scoreboard' for a remote executor tracking various
477     metrics and able to render a periodic dump of global state.
478     """
479
480     def __init__(self, total_worker_count: int) -> None:
481         """
482         Args:
483             total_worker_count: number of workers in the pool
484         """
485         self.worker_count: int = total_worker_count
486         self.known_workers: Set[RemoteWorkerRecord] = set()
487         self.start_time: float = time.time()
488         self.start_per_bundle: Dict[str, Optional[float]] = defaultdict(float)
489         self.end_per_bundle: Dict[str, float] = defaultdict(float)
490         self.finished_bundle_timings_per_worker: Dict[
491             RemoteWorkerRecord, math_utils.NumericPopulation
492         ] = {}
493         self.in_flight_bundles_by_worker: Dict[RemoteWorkerRecord, Set[str]] = {}
494         self.bundle_details_by_uuid: Dict[str, BundleDetails] = {}
495         self.finished_bundle_timings: math_utils.NumericPopulation = (
496             math_utils.NumericPopulation()
497         )
498         self.last_periodic_dump: Optional[float] = None
499         self.total_bundles_submitted: int = 0
500
501         # Protects reads and modification using self.  Also used
502         # as a memory fence for modifications to bundle.
503         self.lock: threading.Lock = threading.Lock()
504
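    # Locking sketch (illustrative only): the *_already_locked variants below
    # assume self.lock is already held; everything else acquires it first:
    #
    #     with status.lock:
    #         status.record_bundle_details_already_locked(details)
    #         status.periodic_dump(total_submitted)
    #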
505     def record_acquire_worker(self, worker: RemoteWorkerRecord, uuid: str) -> None:
506         """Record that bundle with uuid is assigned to a particular worker.
507
508         Args:
509             worker: the record of the worker to which uuid is assigned
510             uuid: the uuid of a bundle that has been assigned to a worker
511         """
512         with self.lock:
513             self.record_acquire_worker_already_locked(worker, uuid)
514
515     def record_acquire_worker_already_locked(
516         self, worker: RemoteWorkerRecord, uuid: str
517     ) -> None:
518         """Same as above but an entry point that doesn't acquire the lock
519         for codepaths where it's already held."""
520         assert self.lock.locked()
521         self.known_workers.add(worker)
522         self.start_per_bundle[uuid] = None
523         x = self.in_flight_bundles_by_worker.get(worker, set())
524         x.add(uuid)
525         self.in_flight_bundles_by_worker[worker] = x
526
527     def record_bundle_details(self, details: BundleDetails) -> None:
528         """Register the details about a bundle of work."""
529         with self.lock:
530             self.record_bundle_details_already_locked(details)
531
532     def record_bundle_details_already_locked(self, details: BundleDetails) -> None:
533         """Same as above but for codepaths that already hold the lock."""
534         assert self.lock.locked()
535         self.bundle_details_by_uuid[details.uuid] = details
536
537     def record_release_worker(
538         self,
539         worker: RemoteWorkerRecord,
540         uuid: str,
541         was_cancelled: bool,
542     ) -> None:
543         """Record that a bundle has released a worker."""
544         with self.lock:
545             self.record_release_worker_already_locked(worker, uuid, was_cancelled)
546
547     def record_release_worker_already_locked(
548         self,
549         worker: RemoteWorkerRecord,
550         uuid: str,
551         was_cancelled: bool,
552     ) -> None:
553         """Same as above but for codepaths that already hold the lock."""
554         assert self.lock.locked()
555         ts = time.time()
556         self.end_per_bundle[uuid] = ts
557         self.in_flight_bundles_by_worker[worker].remove(uuid)
558         if not was_cancelled:
559             start = self.start_per_bundle[uuid]
560             assert start is not None
561             bundle_latency = ts - start
562             x = self.finished_bundle_timings_per_worker.get(
563                 worker, math_utils.NumericPopulation()
564             )
565             x.add_number(bundle_latency)
566             self.finished_bundle_timings_per_worker[worker] = x
567             self.finished_bundle_timings.add_number(bundle_latency)
568
569     def record_processing_began(self, uuid: str):
570         """Record when work on a bundle begins."""
571         with self.lock:
572             self.start_per_bundle[uuid] = time.time()
573
574     def total_in_flight(self) -> int:
575         """How many bundles are in flight currently?"""
576         assert self.lock.locked()
577         total_in_flight = 0
578         for worker in self.known_workers:
579             total_in_flight += len(self.in_flight_bundles_by_worker[worker])
580         return total_in_flight
581
582     def total_idle(self) -> int:
583         """How many idle workers are there currently?"""
584         assert self.lock.locked()
585         return self.worker_count - self.total_in_flight()
586
587     def __repr__(self):
588         assert self.lock.locked()
589         ts = time.time()
590         total_finished = len(self.finished_bundle_timings)
591         total_in_flight = self.total_in_flight()
592         ret = f'\n\n{underline()}Remote Executor Pool Status{reset()}: '
593         qall_median = None
594         qall_p95 = None
595         if len(self.finished_bundle_timings) > 1:
596             qall_median = self.finished_bundle_timings.get_median()
597             qall_p95 = self.finished_bundle_timings.get_percentile(95)
598             ret += (
599                 f'⏱=∀p50:{qall_median:.1f}s, ∀p95:{qall_p95:.1f}s, total={ts-self.start_time:.1f}s, '
600                 f'✅={total_finished}/{self.total_bundles_submitted}, '
601                 f'💻n={total_in_flight}/{self.worker_count}\n'
602             )
603         else:
604             ret += (
605                 f'⏱={ts-self.start_time:.1f}s, '
606                 f'✅={total_finished}/{self.total_bundles_submitted}, '
607                 f'💻n={total_in_flight}/{self.worker_count}\n'
608             )
609
610         for worker in self.known_workers:
611             ret += f'  {fg("lightning yellow")}{worker.machine}{reset()}: '
612             timings = self.finished_bundle_timings_per_worker.get(
613                 worker, math_utils.NumericPopulation()
614             )
615             count = len(timings)
616             qworker_median = None
617             qworker_p95 = None
618             if count > 1:
619                 qworker_median = timings.get_median()
620                 qworker_p95 = timings.get_percentile(95)
621                 ret += f' 💻p50: {qworker_median:.1f}s, 💻p95: {qworker_p95:.1f}s\n'
622             else:
623                 ret += '\n'
624             if count > 0:
625                 ret += f'    ...finished {count} total bundle(s) so far\n'
626             in_flight = len(self.in_flight_bundles_by_worker[worker])
627             if in_flight > 0:
628                 ret += f'    ...{in_flight} bundles currently in flight:\n'
629                 for bundle_uuid in self.in_flight_bundles_by_worker[worker]:
630                     details = self.bundle_details_by_uuid.get(bundle_uuid, None)
631                     pid = str(details.pid) if (details and details.pid != 0) else "TBD"
632                     if self.start_per_bundle[bundle_uuid] is not None:
633                         sec = ts - self.start_per_bundle[bundle_uuid]
634                         ret += f'       (pid={pid}): {details} for {sec:.1f}s so far '
635                     else:
636                         ret += f'       {details} setting up / copying data...'
637                         sec = 0.0
638
639                     if qworker_p95 is not None:
640                         if sec > qworker_p95:
641                             ret += f'{bg("red")}>💻p95{reset()} '
642                             if details is not None:
643                                 details.slower_than_local_p95 = True
644                         else:
645                             if details is not None:
646                                 details.slower_than_local_p95 = False
647
648                     if qall_p95 is not None:
649                         if sec > qall_p95:
650                             ret += f'{bg("red")}>∀p95{reset()} '
651                             if details is not None:
652                                 details.slower_than_global_p95 = True
653                         elif details is not None:
654                             details.slower_than_global_p95 = False
655                     ret += '\n'
656         return ret
657
658     def periodic_dump(self, total_bundles_submitted: int) -> None:
659         assert self.lock.locked()
660         self.total_bundles_submitted = total_bundles_submitted
661         ts = time.time()
662         if self.last_periodic_dump is None or ts - self.last_periodic_dump > 5.0:
663             print(self)
664             self.last_periodic_dump = ts
665
666
667 class RemoteWorkerSelectionPolicy(ABC):
668     """An interface definition of a policy for selecting a remote worker."""
669
670     def __init__(self):
671         self.workers: Optional[List[RemoteWorkerRecord]] = None
672
673     def register_worker_pool(self, workers: List[RemoteWorkerRecord]):
674         self.workers = workers
675
676     @abstractmethod
677     def is_worker_available(self) -> bool:
678         pass
679
680     @abstractmethod
681     def acquire_worker(
682         self, machine_to_avoid: Optional[str] = None
683     ) -> Optional[RemoteWorkerRecord]:
684         pass
685
686
687 class WeightedRandomRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
688     """A remote worker selector that uses weighted RNG."""
689
690     @overrides
691     def is_worker_available(self) -> bool:
692         if self.workers:
693             for worker in self.workers:
694                 if worker.count > 0:
695                     return True
696         return False
697
698     @overrides
699     def acquire_worker(
700         self, machine_to_avoid: Optional[str] = None
701     ) -> Optional[RemoteWorkerRecord]:
702         grabbag = []
703         if self.workers:
704             for worker in self.workers:
705                 if worker.machine != machine_to_avoid:
706                     if worker.count > 0:
707                         for _ in range(worker.count * worker.weight):
708                             grabbag.append(worker)
709
710         if len(grabbag) == 0:
711             logger.debug(
712                 'There are no available workers that avoid %s', machine_to_avoid
713             )
714             if self.workers:
715                 for worker in self.workers:
716                     if worker.count > 0:
717                         for _ in range(worker.count * worker.weight):
718                             grabbag.append(worker)
719
720         if len(grabbag) == 0:
721             logger.warning('There are no available workers?!')
722             return None
723
724         worker = random.sample(grabbag, 1)[0]
725         assert worker.count > 0
726         worker.count -= 1
727         logger.debug('Selected worker %s', worker)
728         return worker
729
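# Selection sketch (illustrative only): a worker's chance of being chosen is
# proportional to weight * count among eligible machines, and every
# acquisition decrements count until RemoteExecutor._release_worker (below)
# increments it again.
#
#     policy = WeightedRandomRemoteWorkerSelectionPolicy()
#     policy.register_worker_pool(workers)   # `workers` as sketched above
#     if policy.is_worker_available():
#         worker = policy.acquire_worker(machine_to_avoid='machine1')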
730
731 class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
732     """A remote worker selector that just round robins."""
733
734     def __init__(self) -> None:
735         super().__init__()
736         self.index = 0
737
738     @overrides
739     def is_worker_available(self) -> bool:
740         if self.workers:
741             for worker in self.workers:
742                 if worker.count > 0:
743                     return True
744         return False
745
746     def _increment_index(self, index: int) -> None:
747         if self.workers:
748             index += 1
749             if index >= len(self.workers):
750                 index = 0
751             self.index = index
752
753     @overrides
754     def acquire_worker(
755         self, machine_to_avoid: Optional[str] = None
756     ) -> Optional[RemoteWorkerRecord]:
757         if self.workers:
758             x = self.index
759             while True:
760                 worker = self.workers[x]
761                 if worker.machine != machine_to_avoid and worker.count > 0:
762                     worker.count -= 1
763                     self._increment_index(x)
764                     logger.debug('Selected worker %s', worker)
765                     return worker
766                 x += 1
767                 if x >= len(self.workers):
768                     x = 0
769                 if x == self.index:
770                     logger.warning('Unexpectedly could not find a worker, retrying...')
771                     return None
772         return None
773
774
775 class RemoteExecutor(BaseExecutor):
776     """An executor that uses processes on remote machines to do work.
777     To do so, it requires that a pool of remote workers be properly
778     configured.  See instructions in
779     :class:`pyutils.parallelize.parallelize`.
780
781     Each machine in a worker pool has a *weight* and a *count*.  A
782     *weight* captures the relative speed of a processor on that worker
783     and a *count* captures the number of simultaneous tasks the worker
784     can accept (i.e. the number of cpus on the machine).
785
786     To dispatch work to a remote machine, this class pickles the code
787     to be executed remotely using `cloudpickle`.  For that to work,
788     the remote machine should be running the same version of Python as
789     this machine, ideally in a virtual environment with the same
790     import libraries installed.  Differences in operating system
791     and/or processor architecture don't seem to matter for most code,
792     though.
793
794     .. warning::
795
796         Mismatches in Python version or in the version numbers of
797         third-party libraries between machines can cause problems
798         when trying to unpickle and run code remotely.
799
800     Work to be dispatched is represented in this code by creating a
801     "bundle".  Each bundle is assigned to a remote worker based on
802     heuristics captured in a :class:`RemoteWorkerSelectionPolicy`.  In
803     general, it attempts to load all workers in the pool and maximize
804     throughput.  Once assigned to a remote worker, pickled code is
805     copied to that worker via `scp` and a remote command is issued via
806     `ssh` to execute a :file:`remote_worker.py` process on the remote
807     machine.  This process unpickles the code, runs it, and produces a
808     result which is then copied back to the local machine (again via
809     `scp`) where it can be processed by local code.
810
811     You can and probably must override the path of
812     :file:`remote_worker.py` on your pool machines using the
813     `--remote_worker_helper_path` commandline argument (or by just
814     changing the default defined above in this file).
815
816     During remote work execution, this local machine acts as a
817     controller dispatching all work to the network, copying pickled
818     tasks out, and copying results back in.  It may also be a worker
819     in the pool but do not underestimate the cost of being a
820     controller -- it takes some cpu and a lot of network bandwidth.
821     The work dispatcher logic attempts to detect when a controller is
822     also a worker and reduce its load.
823
824     Some redundancy and safety provisions are made when scheduling
825     tasks to the worker pool; e.g. slower-than-expected tasks have
826     redundant backup tasks created, especially if there are otherwise
827     idle workers.  If a task fails repeatedly, the dispatcher considers
828     it poisoned and gives up on it.
829
830     .. warning::
831
832         This executor probably only makes sense to use with
833         computationally expensive tasks such as jobs that will execute
834         for ~30 seconds or longer.
835
836         The network overhead and latency of copying work from the
837         controller (local) machine to the remote workers and copying
838         results back again is relatively high.  Especially at startup,
839         the network can become a bottleneck.  Future versions of this
840         code may attempt to split the responsibility of being a
841         controller (distributing work to pool machines).
842
843     Instructions for how to set this up are provided in
844     :class:`pyutils.parallelize.parallelize`.
845
846     See also :class:`ProcessExecutor` and :class:`ThreadExecutor`.
847
848     """
849
850     def __init__(
851         self,
852         workers: List[RemoteWorkerRecord],
853         policy: RemoteWorkerSelectionPolicy,
854     ) -> None:
855         """
856         Args:
857             workers: A list of remote workers we can call on to do tasks.
858             policy: A policy for selecting remote workers for tasks.
859         """
860
861         super().__init__()
862         self.workers = workers
863         self.policy = policy
864         self.worker_count = 0
865         for worker in self.workers:
866             self.worker_count += worker.count
867         if self.worker_count <= 0:
868             msg = f"We need somewhere to schedule work; count was {self.worker_count}"
869             logger.critical(msg)
870             raise RemoteExecutorException(msg)
871         self.policy.register_worker_pool(self.workers)
872         self.cv = threading.Condition()
873         logger.debug(
874             'Creating %d local threads, one per remote worker.', self.worker_count
875         )
876         self._helper_executor = fut.ThreadPoolExecutor(
877             thread_name_prefix="remote_executor_helper",
878             max_workers=self.worker_count,
879         )
880         self.status = RemoteExecutorStatus(self.worker_count)
881         self.total_bundles_submitted = 0
882         self.backup_lock = threading.Lock()
883         self.last_backup = None
884         (
885             self.heartbeat_thread,
886             self.heartbeat_stop_event,
887         ) = self._run_periodic_heartbeat()
888         self.already_shutdown = False
889
890     @background_thread
891     def _run_periodic_heartbeat(self, stop_event: threading.Event) -> None:
892         """
893         We create a background thread to invoke :meth:`_heartbeat` regularly
894         while we are scheduling work.  It does some accounting such as
895         looking for slow bundles to tag for backup creation, checking for
896         unexpected failures, and printing a fancy message on stdout.
897         """
898         while not stop_event.is_set():
899             time.sleep(5.0)
900             logger.debug('Running periodic heartbeat code...')
901             self._heartbeat()
902         logger.debug('Periodic heartbeat thread shutting down.')
903
904     def _heartbeat(self) -> None:
905         # Note: this is invoked on a background thread, not an
906         # executor thread.  Be careful what you do with it b/c it
907         # needs to get back and dump status again periodically.
908         with self.status.lock:
909             self.status.periodic_dump(self.total_bundles_submitted)
910
911             # Look for bundles to reschedule via executor.submit
912             if config.config['executors_schedule_remote_backups']:
913                 self._maybe_schedule_backup_bundles()
914
915     def _maybe_schedule_backup_bundles(self):
916         """Maybe schedule backup bundles if we see a very slow bundle."""
917
918         assert self.status.lock.locked()
919         num_done = len(self.status.finished_bundle_timings)
920         num_idle_workers = self.worker_count - self.task_count
921         now = time.time()
922         if (
923             num_done >= 2
924             and num_idle_workers > 0
925             and (self.last_backup is None or (now - self.last_backup > 9.0))
926             and self.backup_lock.acquire(blocking=False)
927         ):
928             try:
929                 assert self.backup_lock.locked()
930
931                 bundle_to_backup = None
932                 best_score = None
933                 for (
934                     worker,
935                     bundle_uuids,
936                 ) in self.status.in_flight_bundles_by_worker.items():
937
938                     # Prefer to schedule backups of bundles running on
939                     # slower machines.
940                     base_score = 0
941                     for record in self.workers:
942                         if worker.machine == record.machine:
943                             temp_score = float(record.weight)
944                             temp_score = 1.0 / temp_score
945                             temp_score *= 200.0
946                             base_score = int(temp_score)
947                             break
948
949                     for uuid in bundle_uuids:
950                         bundle = self.status.bundle_details_by_uuid.get(uuid, None)
951                         if (
952                             bundle is not None
953                             and bundle.src_bundle is None
954                             and bundle.backup_bundles is not None
955                         ):
956                             score = base_score
957
958                             # Schedule backups of bundles running
959                             # longer; especially those that are
960                             # unexpectedly slow.
961                             start_ts = self.status.start_per_bundle[uuid]
962                             if start_ts is not None:
963                                 runtime = now - start_ts
964                                 score += runtime
965                                 logger.debug(
966                                     'score[%s] => %.1f  # latency boost', bundle, score
967                                 )
968
969                                 if bundle.slower_than_local_p95:
970                                     score += runtime / 2
971                                     logger.debug(
972                                         'score[%s] => %.1f  # >worker p95',
973                                         bundle,
974                                         score,
975                                     )
976
977                                 if bundle.slower_than_global_p95:
978                                     score += runtime / 4
979                                     logger.debug(
980                                         'score[%s] => %.1f  # >global p95',
981                                         bundle,
982                                         score,
983                                     )
984
985                             # Prefer backups of bundles that don't
986                             # have backups already.
987                             backup_count = len(bundle.backup_bundles)
988                             if backup_count == 0:
989                                 score *= 2
990                             elif backup_count == 1:
991                                 score /= 2
992                             elif backup_count == 2:
993                                 score /= 8
994                             else:
995                                 score = 0
996                             logger.debug(
997                                 'score[%s] => %.1f  # %d dup backup factor',
998                                 bundle,
999                                 score, backup_count,
1000                             )
1001
1002                             if score != 0 and (
1003                                 best_score is None or score > best_score
1004                             ):
1005                                 bundle_to_backup = bundle
1006                                 assert bundle is not None
1007                                 assert bundle.backup_bundles is not None
1008                                 assert bundle.src_bundle is None
1009                                 best_score = score
1010
1011                 # Note: this is all still happening on the heartbeat
1012                 # runner thread.  That's ok because
1013                 # _schedule_backup_for_bundle uses the executor to
1014                 # submit the bundle again which will cause it to be
1015                 # picked up by a worker thread and allow this thread
1016                 # to return to run future heartbeats.
1017                 if bundle_to_backup is not None:
1018                     self.last_backup = now
1019                     logger.info(
1020                         '=====> SCHEDULING BACKUP %s (score=%.1f) <=====',
1021                         bundle_to_backup,
1022                         best_score,
1023                     )
1024                     self._schedule_backup_for_bundle(bundle_to_backup)
1025             finally:
1026                 self.backup_lock.release()
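
    # Worked example of the scoring above (illustrative numbers): a bundle
    # that has run for 60s on a weight=2 worker with no backups yet scores
    # (200/2 + 60) * 2 = 320; the same bundle on a weight=1 worker that is
    # also slower than that worker's p95 scores (200/1 + 60 + 30) * 2 = 580,
    # so the straggler on the slower machine gets its backup first.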
1027
1028     def _is_worker_available(self) -> bool:
1029         """Is there a worker available currently?"""
1030         return self.policy.is_worker_available()
1031
1032     def _acquire_worker(
1033         self, machine_to_avoid: Optional[str] = None
1034     ) -> Optional[RemoteWorkerRecord]:
1035         """Try to acquire a worker."""
1036         return self.policy.acquire_worker(machine_to_avoid)
1037
1038     def _find_available_worker_or_block(
1039         self, machine_to_avoid: Optional[str] = None
1040     ) -> RemoteWorkerRecord:
1041         """Find a worker or block until one becomes available."""
1042         with self.cv:
1043             while not self._is_worker_available():
1044                 self.cv.wait()
1045             worker = self._acquire_worker(machine_to_avoid)
1046             if worker is not None:
1047                 return worker
1048         msg = "We should never reach this point in the code"
1049         logger.critical(msg)
1050         raise Exception(msg)
1051
1052     def _release_worker(self, bundle: BundleDetails, *, was_cancelled=True) -> None:
1053         """Release a previously acquired worker."""
1054         worker = bundle.worker
1055         assert worker is not None
1056         logger.debug('Released worker %s', worker)
1057         self.status.record_release_worker(
1058             worker,
1059             bundle.uuid,
1060             was_cancelled,
1061         )
1062         with self.cv:
1063             worker.count += 1
1064             self.cv.notify()
1065         self.adjust_task_count(-1)
1066
1067     def _check_if_cancelled(self, bundle: BundleDetails) -> bool:
1068         """See if a particular bundle is cancelled.  Do not block."""
1069         with self.status.lock:
1070             if bundle.is_cancelled.wait(timeout=0.0):
1071                 logger.debug('Bundle %s is cancelled, bail out.', bundle.uuid)
1072                 bundle.was_cancelled = True
1073                 return True
1074         return False
1075
1076     def _launch(self, bundle: BundleDetails, override_avoid_machine=None) -> Any:
1077         """Find a worker for bundle or block until one is available."""
1078
1079         self.adjust_task_count(+1)
1080         uuid = bundle.uuid
1081         controller = bundle.controller
1082         avoid_machine = override_avoid_machine
1083         is_original = bundle.src_bundle is None
1084
1085         # Try not to schedule a backup on the same host as the original.
1086         if avoid_machine is None and bundle.src_bundle is not None:
1087             avoid_machine = bundle.src_bundle.machine
1088         worker = None
1089         while worker is None:
1090             worker = self._find_available_worker_or_block(avoid_machine)
1091         assert worker is not None
1092
1093         # Ok, found a worker.
1094         bundle.worker = worker
1095         machine = bundle.machine = worker.machine
1096         username = bundle.username = worker.username
1097         self.status.record_acquire_worker(worker, uuid)
1098         logger.debug('%s: Running bundle on %s...', bundle, worker)
1099
1100         # Before we do any work, make sure the bundle is still viable.
1101         # It may have been some time between when it was submitted and
1102         # now due to lack of worker availability and someone else may
1103         # have already finished it.
1104         if self._check_if_cancelled(bundle):
1105             try:
1106                 return self._process_work_result(bundle)
1107             except Exception:
1108                 logger.warning(
1109                     '%s: bundle says it\'s cancelled upfront but no results?!', bundle
1110                 )
1111                 self._release_worker(bundle)
1112                 if is_original:
1113                     # Weird.  We are the original owner of this
1114                     # bundle.  For it to have been cancelled, a backup
1115                     # must have already started and completed before
1116                     # we even got started.  Moreover, the backup says
1117                     # it is done but we can't find the results it
1118                     # should have copied over.  Reschedule the whole
1119                     # thing.
1120                     logger.exception(
1121                         '%s: We are the original owner thread and yet there are '
1122                         'no results for this bundle.  This is unexpected and bad. '
1123                         'Attempting an emergency retry...',
1124                         bundle,
1125                     )
1126                     return self._emergency_retry_nasty_bundle(bundle)
1127                 else:
1128                     # We're a backup and our bundle is cancelled
1129                     # before we even got started.  Do nothing and let
1130                     # the original bundle's thread worry about either
1131                     # finding the results or complaining about it.
1132                     return None
1133
1134         # Send input code / data to worker machine if it's not local.
1135         if controller not in machine:
1136             try:
1137                 cmd = (
1138                     f'{SCP} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
1139                 )
1140                 start_ts = time.time()
1141                 logger.info("%s: Copying work to %s via %s.", bundle, worker, cmd)
1142                 run_silently(cmd)
1143                 xfer_latency = time.time() - start_ts
1144                 logger.debug(
1145                     "%s: Copying to %s took %.1fs.", bundle, worker, xfer_latency
1146                 )
1147             except Exception:
1148                 self._release_worker(bundle)
1149                 if is_original:
1150                     # Weird.  We tried to copy the code to the worker
1151                     # and it failed...  And we're the original bundle.
1152                     # We have to retry.
1153                     logger.exception(
1154                         "%s: Failed to send instructions to the worker machine?! "
1155                         "This is not expected; we\'re the original bundle so this shouldn\'t "
1156                         "be a race condition.  Attempting an emergency retry...",
1157                         bundle,
1158                     )
1159                     return self._emergency_retry_nasty_bundle(bundle)
1160                 else:
1161                     # This is actually expected; we're a backup.
1162                     # There's a race condition where someone else
1163                     # already finished the work and removed the source
1164                     # code_file before we could copy it.  Ignore.
1165                     logger.warning(
1166                         '%s: Failed to send instructions to the worker machine... '
1167                         'We\'re a backup and this may be caused by the original (or '
1168                         'some other backup) already finishing this work.  Ignoring.',
1169                         bundle,
1170                     )
1171                     return None
1172
1173         # Kick off the work.  Note that if this fails we let
1174         # _wait_for_process deal with it.
1175         self.status.record_processing_began(uuid)
1176         helper_path = config.config['remote_worker_helper_path']
1177         cmd = (
1178             f'{SSH} {bundle.username}@{bundle.machine} '
1179             f'"{helper_path} --code_file {bundle.code_file} --result_file {bundle.result_file}"'
1180         )
1181         logger.debug(
1182             '%s: Executing %s in the background to kick off work...', bundle, cmd
1183         )
1184         p = cmd_in_background(cmd, silent=True)
1185         bundle.pid = p.pid
1186         logger.debug(
1187             '%s: Local ssh process pid=%d; remote worker is %s.', bundle, p.pid, machine
1188         )
1189         return self._wait_for_process(p, bundle, 0)
1190
1191     def _wait_for_process(
1192         self, p: Optional[subprocess.Popen], bundle: BundleDetails, depth: int
1193     ) -> Any:
1194         """At this point we've copied the bundle's pickled code to the remote
1195         worker and started an ssh process that should be invoking the
1196         remote worker to have it execute the user's code.  See how
1197         that's going and wait for it to complete or fail.  Note that
1198         this code is recursive: there are codepaths where we decide to
1199         stop waiting for an ssh process (because another backup seems
1200         to have finished) but then fail to fetch or parse the results
1201         from that backup and thus call ourselves to continue waiting
1202         on an active ssh process.  This is the purpose of the depth
1203         argument: to curtail potential infinite recursion by giving up
1204         eventually.
1205
1206         Args:
1207             p: the Popen record of the ssh job
1208             bundle: the bundle of work being executed remotely
1209             depth: how many retries we've made so far.  Starts at zero.
1210
1211         """
1212
1213         machine = bundle.machine
1214         assert p is not None
1215         pid = p.pid  # pid of the ssh process
1216         if depth > 3:
1217             logger.error(
1218                 "I've gotten repeated errors waiting on this bundle; giving up on pid=%d",
1219                 pid,
1220             )
1221             p.terminate()
1222             self._release_worker(bundle)
1223             return self._emergency_retry_nasty_bundle(bundle)
1224
1225         # Spin until either the ssh job we scheduled finishes the
1226         # bundle or some backup worker signals that they finished it
1227         # before we could.
1228         while True:
1229             try:
1230                 p.wait(timeout=0.25)
1231             except subprocess.TimeoutExpired:
1232                 if self._check_if_cancelled(bundle):
1233                     logger.info(
1234                         '%s: looks like another worker finished bundle...', bundle
1235                     )
1236                     break
1237             else:
1238                 logger.info("%s: pid %d (%s) is finished!", bundle, pid, machine)
1239                 p = None
1240                 break
1241
1242         # If we get here we believe the bundle is done; either the ssh
1243         # subprocess finished (hopefully successfully) or we noticed
1244         # that some other worker seems to have completed the bundle
1245         # before us and we're bailing out.
1246         try:
1247             ret = self._process_work_result(bundle)
1248             if ret is not None and p is not None:
1249                 p.terminate()
1250             return ret
1251
1252         # Something went wrong; e.g. we could not copy the results
1253         # back, cleanup after ourselves on the remote machine, or
1254                     # unpickle the results we got from the remote machine.  If we
1255         # still have an active ssh subprocess, keep waiting on it.
1256         # Otherwise, time for an emergency reschedule.
1257         except Exception:
1258             logger.exception('%s: Something unexpected just happened...', bundle)
1259             if p is not None:
1260                 logger.warning(
1261                     "%s: Failed to wrap up \"done\" bundle, re-waiting on active ssh.",
1262                     bundle,
1263                 )
1264                 return self._wait_for_process(p, bundle, depth + 1)
1265             else:
1266                 self._release_worker(bundle)
1267                 return self._emergency_retry_nasty_bundle(bundle)
1268
1269     def _process_work_result(self, bundle: BundleDetails) -> Any:
1270         """A bundle seems to be completed.  Check on the results."""
1271
1272         with self.status.lock:
1273             is_original = bundle.src_bundle is None
1274             was_cancelled = bundle.was_cancelled
1275             username = bundle.username
1276             machine = bundle.machine
1277             result_file = bundle.result_file
1278             code_file = bundle.code_file
1279
1280             # Whether original or backup, if we finished first we must
1281             # fetch the results if the computation happened on a
1282             # remote machine.
1283             bundle.end_ts = time.time()
1284             if not was_cancelled:
1285                 assert bundle.machine is not None
1286                 if bundle.controller not in bundle.machine:
1287                     cmd = f'{SCP} {username}@{machine}:{result_file} {result_file} 2>/dev/null'
1288                     logger.info(
1289                         "%s: Fetching results back from %s@%s via %s",
1290                         bundle,
1291                         username,
1292                         machine,
1293                         cmd,
1294                     )
1295
1296                     # If either of these throw they are handled in
1297                     # _wait_for_process.
1298                     attempts = 0
1299                     while True:
1300                         try:
1301                             run_silently(cmd)
1302                         except Exception as e:
1303                             attempts += 1
1304                             if attempts >= 3:
1305                                 raise e
1306                         else:
1307                             break
1308
1309                     # Cleanup remote /tmp files.
1310                     run_silently(
1311                         f'{SSH} {username}@{machine}'
1312                         f' "/bin/rm -f {code_file} {result_file}"'
1313                     )
1314                     logger.debug(
1315                         'Fetching results back took %.2fs', time.time() - bundle.end_ts
1316                     )
1317                 dur = bundle.end_ts - bundle.start_ts
1318                 self.histogram.add_item(dur)
1319
1320         # Only the original bundle should unpickle the file contents,
1321         # though, since it's the only one whose result matters.  The
1322         # original is also the only job that may delete result_file
1323         # from disk.  Note that the original may have been cancelled
1324         # if one of the backups finished first; it must still read the
1325         # result from disk, which it does here even with is_cancelled
1326         # set.
1327         if is_original:
1328             logger.debug("%s: Unpickling %s.", bundle, result_file)
1329             try:
1330                 with open(result_file, 'rb') as rb:
1331                     serialized = rb.read()
1332                 result = cloudpickle.loads(serialized)
1333             except Exception as e:
1334                 logger.exception('Failed to load %s... this is bad news.', result_file)
1335                 self._release_worker(bundle)
1336
1337                 # Re-raise the exception; the code in _wait_for_process may
1338                 # decide to _emergency_retry_nasty_bundle here.
1339                 raise e
1340             logger.debug('Removing local (controller) %s and %s.', code_file, result_file)
1341             os.remove(result_file)
1342             os.remove(code_file)
1343
1344             # Notify any backups that the original is done so they
1345             # should stop ASAP.  Do this whether or not we
1346             # finished first since there could be more than one
1347             # backup.
1348             if bundle.backup_bundles is not None:
1349                 for backup in bundle.backup_bundles:
1350                     logger.debug(
1351                         '%s: Notifying backup %s that it\'s cancelled',
1352                         bundle,
1353                         backup.uuid,
1354                     )
1355                     backup.is_cancelled.set()
1356
1357         # This is a backup job and, by now, we have already fetched
1358         # the bundle results.
1359         else:
1360             # Backup results don't matter, they just need to leave the
1361             # result file in the right place for their originals to
1362             # read/unpickle later.
1363             result = None
1364
1365             # Tell the original to stop if we finished first.
1366             if not was_cancelled:
1367                 orig_bundle = bundle.src_bundle
1368                 assert orig_bundle is not None
1369                 logger.debug(
1370                     '%s: Notifying original %s we beat them to it.',
1371                     bundle,
1372                     orig_bundle.uuid,
1373                 )
1374                 orig_bundle.is_cancelled.set()
1375         self._release_worker(bundle, was_cancelled=was_cancelled)
1376         return result
1377
1378     def _create_original_bundle(self, pickle, function_name: str):
1379         """Creates a bundle that is not a backup of any other bundle but
1380         rather represents a user task.
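
             A rough sketch of how :meth:`submit` (below) drives this::

                 pickle = _make_cloud_pickle(function, *args, **kwargs)
                 bundle = self._create_original_bundle(pickle, function.__name__)
                 future = self._helper_executor.submit(self._launch, bundle)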
1381         """
1382
1383         uuid = string_utils.generate_uuid(omit_dashes=True)
1384         code_file = f'/tmp/{uuid}.code.bin'
1385         result_file = f'/tmp/{uuid}.result.bin'
1386
1387         logger.debug('Writing pickled code to %s', code_file)
1388         with open(code_file, 'wb') as wb:
1389             wb.write(pickle)
1390
1391         bundle = BundleDetails(
1392             pickled_code=pickle,
1393             uuid=uuid,
1394             function_name=function_name,
1395             worker=None,
1396             username=None,
1397             machine=None,
1398             controller=platform.node(),
1399             code_file=code_file,
1400             result_file=result_file,
1401             pid=0,
1402             start_ts=time.time(),
1403             end_ts=0.0,
1404             slower_than_local_p95=False,
1405             slower_than_global_p95=False,
1406             src_bundle=None,
1407             is_cancelled=threading.Event(),
1408             was_cancelled=False,
1409             backup_bundles=[],
1410             failure_count=0,
1411         )
1412         self.status.record_bundle_details(bundle)
1413         logger.debug('%s: Created an original bundle', bundle)
1414         return bundle
1415
1416     def _create_backup_bundle(self, src_bundle: BundleDetails):
1417         """Creates a bundle that is a backup of another bundle that is
1418         running too slowly."""
1419
1420         assert self.status.lock.locked()
1421         assert src_bundle.backup_bundles is not None
1422         n = len(src_bundle.backup_bundles)
1423         uuid = src_bundle.uuid + f'_backup#{n}'
1424
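             # Note: the backup deliberately reuses the original's code_file
             # and result_file so that whichever copy finishes first leaves
             # the result where _process_work_result expects to find it.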
1425         backup_bundle = BundleDetails(
1426             pickled_code=src_bundle.pickled_code,
1427             uuid=uuid,
1428             function_name=src_bundle.function_name,
1429             worker=None,
1430             username=None,
1431             machine=None,
1432             controller=src_bundle.controller,
1433             code_file=src_bundle.code_file,
1434             result_file=src_bundle.result_file,
1435             pid=0,
1436             start_ts=time.time(),
1437             end_ts=0.0,
1438             slower_than_local_p95=False,
1439             slower_than_global_p95=False,
1440             src_bundle=src_bundle,
1441             is_cancelled=threading.Event(),
1442             was_cancelled=False,
1443             backup_bundles=None,  # backup backups not allowed
1444             failure_count=0,
1445         )
1446         src_bundle.backup_bundles.append(backup_bundle)
1447         self.status.record_bundle_details_already_locked(backup_bundle)
1448         logger.debug('%s: Created a backup bundle', backup_bundle)
1449         return backup_bundle
1450
1451     def _schedule_backup_for_bundle(self, src_bundle: BundleDetails):
1452         """Schedule a backup of src_bundle."""
1453
1454         assert self.status.lock.locked()
1455         assert src_bundle is not None
1456         backup_bundle = self._create_backup_bundle(src_bundle)
1457         logger.debug(
1458             '%s/%s: Scheduling backup for execution...',
1459             backup_bundle.uuid,
1460             backup_bundle.function_name,
1461         )
1462         self._helper_executor.submit(self._launch, backup_bundle)
1463
1464         # Results from backups don't matter; if they finish first
1465         # they will move the result_file to this machine and let
1466         # the original pick them up and unpickle them (and return
1467         # a result).
1468
1469     def _emergency_retry_nasty_bundle(
1470         self, bundle: BundleDetails
1471     ) -> Optional[fut.Future]:
1472         """Something unexpectedly failed with bundle.  Either retry it
1473         from the beginning or throw in the towel and give up on it."""
1474
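             # Policy, as implemented below: bump failure_count; if it now
             # exceeds the limit (three failures for an original bundle, two
             # for a backup) give up, otherwise re-launch while avoiding the
             # machine that just failed.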
1475         is_original = bundle.src_bundle is None
1476         bundle.worker = None
1477         avoid_last_machine = bundle.machine
1478         bundle.machine = None
1479         bundle.username = None
1480         bundle.failure_count += 1
1481         if is_original:
1482             retry_limit = 3
1483         else:
1484             retry_limit = 2
1485
1486         if bundle.failure_count > retry_limit:
1487             logger.error(
1488                 '%s: Tried this bundle too many times already (%dx); giving up.',
1489                 bundle,
1490                 retry_limit,
1491             )
1492             if is_original:
1493                 raise RemoteExecutorException(
1494                     f'{bundle}: This bundle can\'t be completed despite several backups and retries',
1495                 )
1496             logger.error(
1497                 '%s: At least it\'s only a backup; better luck with the others.',
1498                 bundle,
1499             )
1500             return None
1501         else:
1502             msg = f'>>> Emergency rescheduling {bundle} because of unexpected errors (wtf?!) <<<'
1503             logger.warning(msg)
1504             warnings.warn(msg)
1505             return self._launch(bundle, avoid_last_machine)
1506
1507     @overrides
1508     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
1509         """Submit work to be done.  This is the user entry point of this
1510         class."""
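             # A sketch of direct use (assumes a properly configured remote
             # worker pool per the module docstring; `executor` here stands
             # for an instance of this class).  The Future's result is the
             # unpickled return value of `function`:
             #
             #   def add(a: int, b: int) -> int:
             #       return a + b
             #
             #   f = executor.submit(add, 1, 2)
             #   assert f.result() == 3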
1511         if self.already_shutdown:
1512             raise Exception('Submitted work after shutdown.')
1513         pickle = _make_cloud_pickle(function, *args, **kwargs)
1514         bundle = self._create_original_bundle(pickle, function.__name__)
1515         self.total_bundles_submitted += 1
1516         return self._helper_executor.submit(self._launch, bundle)
1517
1518     @overrides
1519     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
1520         """Shut down the executor."""
1521         if not self.already_shutdown:
1522             logger.debug('Shutting down RemoteExecutor %s', self.title)
1523             self.heartbeat_stop_event.set()
1524             self.heartbeat_thread.join()
1525             self._helper_executor.shutdown(wait)
1526             if not quiet:
1527                 print(self.histogram.__repr__(label_formatter='%ds'))
1528             self.already_shutdown = True
1529
1530
1531 class RemoteWorkerPoolProvider(ABC):
1532     @abstractmethod
1533     def get_remote_workers(self) -> List[RemoteWorkerRecord]:
1534         pass
1535
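     # Any provider only needs to return the RemoteWorkerRecords to use.  A
     # hypothetical static provider (a sketch, not part of this module) might
     # look like:
     #
     #   class StaticRemoteWorkerPoolProvider(RemoteWorkerPoolProvider):
     #       def __init__(self, records: List[RemoteWorkerRecord]):
     #           self.records = records
     #
     #       @overrides
     #       def get_remote_workers(self) -> List[RemoteWorkerRecord]:
     #           return self.records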
1536
1537 @persistent.persistent_autoloaded_singleton()  # type: ignore
1538 class ConfigRemoteWorkerPoolProvider(
1539     RemoteWorkerPoolProvider, persistent.JsonFileBasedPersistent
1540 ):
1541     def __init__(self, json_remote_worker_pool: Dict[str, Any]):
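             # The incoming dict is parsed JSON; roughly (a sketch; see the
             # example config linked from the module docstring and the
             # RemoteWorkerRecord dataclass earlier in this file for the
             # authoritative field list):
             #
             #   {
             #       "remote_worker_records": [
             #           {"username": "scott", "machine": "cheetah", "count": 4, ...},
             #           ...
             #       ]
             #   }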
1542         self.remote_worker_pool: List[RemoteWorkerRecord] = []
1543         for record in json_remote_worker_pool['remote_worker_records']:
1544             self.remote_worker_pool.append(
1545                 dataclass_utils.dataclass_from_dict(RemoteWorkerRecord, record)
1546             )
1547         assert len(self.remote_worker_pool) > 0
1548
1549     @overrides
1550     def get_remote_workers(self) -> List[RemoteWorkerRecord]:
1551         return self.remote_worker_pool
1552
1553     @overrides
1554     def get_persistent_data(self) -> List[RemoteWorkerRecord]:
1555         return self.remote_worker_pool
1556
1557     @staticmethod
1558     @overrides
1559     def get_filename() -> str:
1560         return type_utils.unwrap_optional(config.config['remote_worker_records_file'])
1561
1562     @staticmethod
1563     @overrides
1564     def should_we_load_data(filename: str) -> bool:
1565         return True
1566
1567     @staticmethod
1568     @overrides
1569     def should_we_save_data(filename: str) -> bool:
1570         return False
1571
1572
1573 @singleton
1574 class DefaultExecutors(object):
1575     """A container for a default thread, process and remote executor.
1576     These are not created until needed and we take care to clean up
1577     before process exit automatically for the caller's convenience.
1578     Instead of creating your own executor, consider using the one
1579     from this pool.  e.g.::
1580
1581         @par.parallelize(method=par.Method.PROCESS)
1582         def do_work(
1583             solutions: List[Work],
1584             shard_num: int,
1585             ...
1586         ):
1587             <do the work>
1588
1589
1590         def start_do_work(all_work: List[Work]):
1591             shards = []
1592             logger.debug('Sharding work into groups of 10.')
1593             for subset in list_utils.shard(all_work, 10):
1594                 shards.append([x for x in subset])
1595
1596             logger.debug('Kicking off helper pool.')
                 results = []
1597             try:
1598                 for n, shard in enumerate(shards):
1599                     results.append(
1600                         do_work(
1601                             shard, n, shared_cache.get_name(), max_letter_pop_per_word
1602                         )
1603                     )
1604                 smart_future.wait_all(results)
1605             finally:
1606                 # Note: if you forget to do this it will clean itself up
1607                 # during program termination including tearing down any
1608                 # active ssh connections.
1609                 executors.DefaultExecutors().process_pool().shutdown()
1610     """
1611
1612     def __init__(self):
1613         self.thread_executor: Optional[ThreadExecutor] = None
1614         self.process_executor: Optional[ProcessExecutor] = None
1615         self.remote_executor: Optional[RemoteExecutor] = None
1616
1617     @staticmethod
1618     def _ping(host) -> bool:
1619         logger.debug('RUN> ping -c 1 %s', host)
1620         try:
1621             x = cmd_exitcode(
1622                 f'ping -c 1 {host} >/dev/null 2>/dev/null', timeout_seconds=1.0
1623             )
1624             return x == 0
1625         except Exception:
1626             return False
1627
1628     def thread_pool(self) -> ThreadExecutor:
1629         if self.thread_executor is None:
1630             self.thread_executor = ThreadExecutor()
1631         return self.thread_executor
1632
1633     def process_pool(self) -> ProcessExecutor:
1634         if self.process_executor is None:
1635             self.process_executor = ProcessExecutor()
1636         return self.process_executor
1637
1638     def remote_pool(self) -> RemoteExecutor:
1639         if self.remote_executor is None:
1640             logger.info('Looking for some helper machines...')
1641             provider = ConfigRemoteWorkerPoolProvider()
1642             all_machines = provider.get_remote_workers()
1643             pool = []
1644
1645             # Make sure we can ping each machine.
1646             for record in all_machines:
1647                 if self._ping(record.machine):
1648                     logger.info('%s is alive / responding to pings', record.machine)
1649                     pool.append(record)
1650
1651             # The controller machine has a lot to do; go easy on it.
1652             for record in pool:
1653                 if record.machine == platform.node() and record.count > 1:
1654                     logger.info('Reducing workload for %s.', record.machine)
1655                     record.count = max(int(record.count / 2), 1)
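                         # e.g. count=4 becomes 2, count=3 or 2 becomes 1; a
                         # count of 1 is left alone by the check above.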
1656
1657             policy = WeightedRandomRemoteWorkerSelectionPolicy()
1658             policy.register_worker_pool(pool)
1659             self.remote_executor = RemoteExecutor(pool, policy)
1660         return self.remote_executor
1661
1662     def shutdown(self) -> None:
1663         if self.thread_executor is not None:
1664             self.thread_executor.shutdown(wait=True, quiet=True)
1665             self.thread_executor = None
1666         if self.process_executor is not None:
1667             self.process_executor.shutdown(wait=True, quiet=True)
1668             self.process_executor = None
1669         if self.remote_executor is not None:
1670             self.remote_executor.shutdown(wait=True, quiet=True)
1671             self.remote_executor = None