fd70e327b75b81a25e6e20a470de3c36a296d125
[pyutils.git] / src / pyutils / parallelize / executors.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3
4 # © Copyright 2021-2022, Scott Gasch
5
6 """Defines three executors: a thread executor for doing work using a
7 threadpool, a process executor for doing work in other processes on
8 the same machine and a remote executor for farming out work to other
9 machines.
10
11 Also defines DefaultExecutors which is a container for references to
12 global executors / worker pools with automatic shutdown semantics."""
13
14 from __future__ import annotations
15
16 import concurrent.futures as fut
17 import logging
18 import os
19 import platform
20 import random
21 import subprocess
22 import threading
23 import time
24 import warnings
25 from abc import ABC, abstractmethod
26 from collections import defaultdict
27 from dataclasses import dataclass, fields
28 from typing import Any, Callable, Dict, List, Optional, Set
29
30 import cloudpickle  # type: ignore
31 from overrides import overrides
32
33 import pyutils.typez.histogram as hist
34 from pyutils import argparse_utils, config, math_utils, persistent, string_utils
35 from pyutils.ansi import bg, fg, reset, underline
36 from pyutils.decorator_utils import singleton
37 from pyutils.exec_utils import cmd_exitcode, cmd_in_background, run_silently
38 from pyutils.parallelize.thread_utils import background_thread
39
40 logger = logging.getLogger(__name__)
41
42 parser = config.add_commandline_args(
43     f"Executors ({__file__})", "Args related to processing executors."
44 )
45 parser.add_argument(
46     '--executors_threadpool_size',
47     type=int,
48     metavar='#THREADS',
49     help='Number of threads in the default threadpool, leave unset for default',
50     default=None,
51 )
52 parser.add_argument(
53     '--executors_processpool_size',
54     type=int,
55     metavar='#PROCESSES',
56     help='Number of processes in the default processpool, leave unset for default',
57     default=None,
58 )
59 parser.add_argument(
60     '--executors_schedule_remote_backups',
61     default=True,
62     action=argparse_utils.ActionNoYes,
63     help='Should we schedule duplicative backup work if a remote bundle is slow',
64 )
65 parser.add_argument(
66     '--executors_max_bundle_failures',
67     type=int,
68     default=3,
69     metavar='#FAILURES',
70     help='Maximum number of failures before giving up on a bundle',
71 )
72 parser.add_argument(
73     '--remote_worker_records_file',
74     type=str,
75     metavar='FILENAME',
76     help='Path of the remote worker records file (JSON)',
77     default=f'{os.environ.get("HOME", ".")}/.remote_worker_records',
78 )
79
80
81 SSH = '/usr/bin/ssh -oForwardX11=no'
82 SCP = '/usr/bin/scp -C'
83
84
85 def _make_cloud_pickle(fun, *args, **kwargs):
86     """Internal helper to create cloud pickles."""
87     logger.debug("Making cloudpickled bundle at %s", fun.__name__)
88     return cloudpickle.dumps((fun, args, kwargs))
89
90
91 class BaseExecutor(ABC):
92     """The base executor interface definition.  The interface for
93     :class:`ProcessExecutor`, :class:`RemoteExecutor`, and
94     :class:`ThreadExecutor`.
95     """
96
97     def __init__(self, *, title=''):
98         self.title = title
99         self.histogram = hist.SimpleHistogram(
100             hist.SimpleHistogram.n_evenly_spaced_buckets(int(0), int(500), 50)
101         )
102         self.task_count = 0
103
104     @abstractmethod
105     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
106         pass
107
108     @abstractmethod
109     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
110         pass
111
112     def shutdown_if_idle(self, *, quiet: bool = False) -> bool:
113         """Shutdown the executor and return True if the executor is idle
114         (i.e. there are no pending or active tasks).  Return False
115         otherwise.  Note: this should only be called by the launcher
116         process.
117
118         """
119         if self.task_count == 0:
120             self.shutdown(wait=True, quiet=quiet)
121             return True
122         return False
123
124     def adjust_task_count(self, delta: int) -> None:
125         """Change the task count.  Note: do not call this method from a
126         worker, it should only be called by the launcher process /
127         thread / machine.
128
129         """
130         self.task_count += delta
131         logger.debug('Adjusted task count by %d to %d.', delta, self.task_count)
132
133     def get_task_count(self) -> int:
134         """Change the task count.  Note: do not call this method from a
135         worker, it should only be called by the launcher process /
136         thread / machine.
137
138         """
139         return self.task_count
140
141
142 class ThreadExecutor(BaseExecutor):
143     """A threadpool executor.  This executor uses python threads to
144     schedule tasks.  Note that, at least as of python3.10, because of
145     the global lock in the interpreter itself, these do not
146     parallelize very well so this class is useful mostly for non-CPU
147     intensive tasks.
148
149     See also :class:`ProcessExecutor` and :class:`RemoteExecutor`.
150     """
151
152     def __init__(self, max_workers: Optional[int] = None):
153         super().__init__()
154         workers = None
155         if max_workers is not None:
156             workers = max_workers
157         elif 'executors_threadpool_size' in config.config:
158             workers = config.config['executors_threadpool_size']
159         if workers is not None:
160             logger.debug('Creating threadpool executor with %d workers', workers)
161         else:
162             logger.debug('Creating a default sized threadpool executor')
163         self._thread_pool_executor = fut.ThreadPoolExecutor(
164             max_workers=workers, thread_name_prefix="thread_executor_helper"
165         )
166         self.already_shutdown = False
167
168     # This is run on a different thread; do not adjust task count here.
169     @staticmethod
170     def run_local_bundle(fun, *args, **kwargs):
171         logger.debug("Running local bundle at %s", fun.__name__)
172         result = fun(*args, **kwargs)
173         return result
174
175     @overrides
176     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
177         if self.already_shutdown:
178             raise Exception('Submitted work after shutdown.')
179         self.adjust_task_count(+1)
180         newargs = []
181         newargs.append(function)
182         for arg in args:
183             newargs.append(arg)
184         start = time.time()
185         result = self._thread_pool_executor.submit(
186             ThreadExecutor.run_local_bundle, *newargs, **kwargs
187         )
188         result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
189         result.add_done_callback(lambda _: self.adjust_task_count(-1))
190         return result
191
192     @overrides
193     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
194         if not self.already_shutdown:
195             logger.debug('Shutting down threadpool executor %s', self.title)
196             self._thread_pool_executor.shutdown(wait)
197             if not quiet:
198                 print(self.histogram.__repr__(label_formatter='%ds'))
199             self.already_shutdown = True
200
201
202 class ProcessExecutor(BaseExecutor):
203     """An executor which runs tasks in child processes.
204
205     See also :class:`ThreadExecutor` and :class:`RemoteExecutor`.
206     """
207
208     def __init__(self, max_workers=None):
209         super().__init__()
210         workers = None
211         if max_workers is not None:
212             workers = max_workers
213         elif 'executors_processpool_size' in config.config:
214             workers = config.config['executors_processpool_size']
215         if workers is not None:
216             logger.debug('Creating processpool executor with %d workers.', workers)
217         else:
218             logger.debug('Creating a default sized processpool executor')
219         self._process_executor = fut.ProcessPoolExecutor(
220             max_workers=workers,
221         )
222         self.already_shutdown = False
223
224     # This is run in another process; do not adjust task count here.
225     @staticmethod
226     def run_cloud_pickle(pickle):
227         fun, args, kwargs = cloudpickle.loads(pickle)
228         logger.debug("Running pickled bundle at %s", fun.__name__)
229         result = fun(*args, **kwargs)
230         return result
231
232     @overrides
233     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
234         if self.already_shutdown:
235             raise Exception('Submitted work after shutdown.')
236         start = time.time()
237         self.adjust_task_count(+1)
238         pickle = _make_cloud_pickle(function, *args, **kwargs)
239         result = self._process_executor.submit(ProcessExecutor.run_cloud_pickle, pickle)
240         result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
241         result.add_done_callback(lambda _: self.adjust_task_count(-1))
242         return result
243
244     @overrides
245     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
246         if not self.already_shutdown:
247             logger.debug('Shutting down processpool executor %s', self.title)
248             self._process_executor.shutdown(wait)
249             if not quiet:
250                 print(self.histogram.__repr__(label_formatter='%ds'))
251             self.already_shutdown = True
252
253     def __getstate__(self):
254         state = self.__dict__.copy()
255         state['_process_executor'] = None
256         return state
257
258
259 class RemoteExecutorException(Exception):
260     """Thrown when a bundle cannot be executed despite several retries."""
261
262     pass
263
264
265 @dataclass
266 class RemoteWorkerRecord:
267     """A record of info about a remote worker."""
268
269     username: str
270     """Username we can ssh into on this machine to run work."""
271
272     machine: str
273     """Machine address / name."""
274
275     weight: int
276     """Relative probability for the weighted policy to select this
277     machine for scheduling work."""
278
279     count: int
280     """If this machine is selected, what is the maximum number of task
281     that it can handle?"""
282
283     def __hash__(self):
284         return hash((self.username, self.machine))
285
286     def __repr__(self):
287         return f'{self.username}@{self.machine}'
288
289
290 @dataclass
291 class BundleDetails:
292     """All info necessary to define some unit of work that needs to be
293     done, where it is being run, its state, whether it is an original
294     bundle of a backup bundle, how many times it has failed, etc...
295     """
296
297     pickled_code: bytes
298     """The code to run, cloud pickled"""
299
300     uuid: str
301     """A unique identifier"""
302
303     function_name: str
304     """The name of the function we pickled"""
305
306     worker: Optional[RemoteWorkerRecord]
307     """The remote worker running this bundle or None if none (yet)"""
308
309     username: Optional[str]
310     """The remote username running this bundle or None if none (yet)"""
311
312     machine: Optional[str]
313     """The remote machine running this bundle or None if none (yet)"""
314
315     hostname: str
316     """The controller machine"""
317
318     code_file: str
319     """A unique filename to hold the work to be done"""
320
321     result_file: str
322     """Where the results should be placed / read from"""
323
324     pid: int
325     """The process id of the local subprocess watching the ssh connection
326     to the remote machine"""
327
328     start_ts: float
329     """Starting time"""
330
331     end_ts: float
332     """Ending time"""
333
334     slower_than_local_p95: bool
335     """Currently slower then 95% of other bundles on remote host"""
336
337     slower_than_global_p95: bool
338     """Currently slower than 95% of other bundles globally"""
339
340     src_bundle: Optional[BundleDetails]
341     """If this is a backup bundle, this points to the original bundle
342     that it's backing up.  None otherwise."""
343
344     is_cancelled: threading.Event
345     """An event that can be signaled to indicate this bundle is cancelled.
346     This is set when another copy (backup or original) of this work has
347     completed successfully elsewhere."""
348
349     was_cancelled: bool
350     """True if this bundle was cancelled, False if it finished normally"""
351
352     backup_bundles: Optional[List[BundleDetails]]
353     """If we've created backups of this bundle, this is the list of them"""
354
355     failure_count: int
356     """How many times has this bundle failed already?"""
357
358     def __repr__(self):
359         uuid = self.uuid
360         if uuid[-9:-2] == '_backup':
361             uuid = uuid[:-9]
362             suffix = f'{uuid[-6:]}_b{self.uuid[-1:]}'
363         else:
364             suffix = uuid[-6:]
365
366         # We colorize the uuid based on some bits from it to make them
367         # stand out in the logging and help a reader correlate log messages
368         # related to the same bundle.
369         colorz = [
370             fg('violet red'),
371             fg('red'),
372             fg('orange'),
373             fg('peach orange'),
374             fg('yellow'),
375             fg('marigold yellow'),
376             fg('green yellow'),
377             fg('tea green'),
378             fg('cornflower blue'),
379             fg('turquoise blue'),
380             fg('tropical blue'),
381             fg('lavender purple'),
382             fg('medium purple'),
383         ]
384         c = colorz[int(uuid[-2:], 16) % len(colorz)]
385         function_name = (
386             self.function_name if self.function_name is not None else 'nofname'
387         )
388         machine = self.machine if self.machine is not None else 'nomachine'
389         return f'{c}{suffix}/{function_name}/{machine}{reset()}'
390
391
392 class RemoteExecutorStatus:
393     """A status 'scoreboard' for a remote executor tracking various
394     metrics and able to render a periodic dump of global state.
395     """
396
397     def __init__(self, total_worker_count: int) -> None:
398         """C'tor.
399
400         Args:
401             total_worker_count: number of workers in the pool
402
403         """
404         self.worker_count: int = total_worker_count
405         self.known_workers: Set[RemoteWorkerRecord] = set()
406         self.start_time: float = time.time()
407         self.start_per_bundle: Dict[str, Optional[float]] = defaultdict(float)
408         self.end_per_bundle: Dict[str, float] = defaultdict(float)
409         self.finished_bundle_timings_per_worker: Dict[
410             RemoteWorkerRecord, math_utils.NumericPopulation
411         ] = {}
412         self.in_flight_bundles_by_worker: Dict[RemoteWorkerRecord, Set[str]] = {}
413         self.bundle_details_by_uuid: Dict[str, BundleDetails] = {}
414         self.finished_bundle_timings: math_utils.NumericPopulation = (
415             math_utils.NumericPopulation()
416         )
417         self.last_periodic_dump: Optional[float] = None
418         self.total_bundles_submitted: int = 0
419
420         # Protects reads and modification using self.  Also used
421         # as a memory fence for modifications to bundle.
422         self.lock: threading.Lock = threading.Lock()
423
424     def record_acquire_worker(self, worker: RemoteWorkerRecord, uuid: str) -> None:
425         """Record that bundle with uuid is assigned to a particular worker.
426
427         Args:
428             worker: the record of the worker to which uuid is assigned
429             uuid: the uuid of a bundle that has been assigned to a worker
430         """
431         with self.lock:
432             self.record_acquire_worker_already_locked(worker, uuid)
433
434     def record_acquire_worker_already_locked(
435         self, worker: RemoteWorkerRecord, uuid: str
436     ) -> None:
437         """Same as above but an entry point that doesn't acquire the lock
438         for codepaths where it's already held."""
439         assert self.lock.locked()
440         self.known_workers.add(worker)
441         self.start_per_bundle[uuid] = None
442         x = self.in_flight_bundles_by_worker.get(worker, set())
443         x.add(uuid)
444         self.in_flight_bundles_by_worker[worker] = x
445
446     def record_bundle_details(self, details: BundleDetails) -> None:
447         """Register the details about a bundle of work."""
448         with self.lock:
449             self.record_bundle_details_already_locked(details)
450
451     def record_bundle_details_already_locked(self, details: BundleDetails) -> None:
452         """Same as above but for codepaths that already hold the lock."""
453         assert self.lock.locked()
454         self.bundle_details_by_uuid[details.uuid] = details
455
456     def record_release_worker(
457         self,
458         worker: RemoteWorkerRecord,
459         uuid: str,
460         was_cancelled: bool,
461     ) -> None:
462         """Record that a bundle has released a worker."""
463         with self.lock:
464             self.record_release_worker_already_locked(worker, uuid, was_cancelled)
465
466     def record_release_worker_already_locked(
467         self,
468         worker: RemoteWorkerRecord,
469         uuid: str,
470         was_cancelled: bool,
471     ) -> None:
472         """Same as above but for codepaths that already hold the lock."""
473         assert self.lock.locked()
474         ts = time.time()
475         self.end_per_bundle[uuid] = ts
476         self.in_flight_bundles_by_worker[worker].remove(uuid)
477         if not was_cancelled:
478             start = self.start_per_bundle[uuid]
479             assert start is not None
480             bundle_latency = ts - start
481             x = self.finished_bundle_timings_per_worker.get(
482                 worker, math_utils.NumericPopulation()
483             )
484             x.add_number(bundle_latency)
485             self.finished_bundle_timings_per_worker[worker] = x
486             self.finished_bundle_timings.add_number(bundle_latency)
487
488     def record_processing_began(self, uuid: str):
489         """Record when work on a bundle begins."""
490         with self.lock:
491             self.start_per_bundle[uuid] = time.time()
492
493     def total_in_flight(self) -> int:
494         """How many bundles are in flight currently?"""
495         assert self.lock.locked()
496         total_in_flight = 0
497         for worker in self.known_workers:
498             total_in_flight += len(self.in_flight_bundles_by_worker[worker])
499         return total_in_flight
500
501     def total_idle(self) -> int:
502         """How many idle workers are there currently?"""
503         assert self.lock.locked()
504         return self.worker_count - self.total_in_flight()
505
506     def __repr__(self):
507         assert self.lock.locked()
508         ts = time.time()
509         total_finished = len(self.finished_bundle_timings)
510         total_in_flight = self.total_in_flight()
511         ret = f'\n\n{underline()}Remote Executor Pool Status{reset()}: '
512         qall = None
513         if len(self.finished_bundle_timings) > 1:
514             qall_median = self.finished_bundle_timings.get_median()
515             qall_p95 = self.finished_bundle_timings.get_percentile(95)
516             ret += (
517                 f'⏱=∀p50:{qall_median:.1f}s, ∀p95:{qall_p95:.1f}s, total={ts-self.start_time:.1f}s, '
518                 f'✅={total_finished}/{self.total_bundles_submitted}, '
519                 f'💻n={total_in_flight}/{self.worker_count}\n'
520             )
521         else:
522             ret += (
523                 f'⏱={ts-self.start_time:.1f}s, '
524                 f'✅={total_finished}/{self.total_bundles_submitted}, '
525                 f'💻n={total_in_flight}/{self.worker_count}\n'
526             )
527
528         for worker in self.known_workers:
529             ret += f'  {fg("lightning yellow")}{worker.machine}{reset()}: '
530             timings = self.finished_bundle_timings_per_worker.get(
531                 worker, math_utils.NumericPopulation()
532             )
533             count = len(timings)
534             qworker_median = None
535             qworker_p95 = None
536             if count > 1:
537                 qworker_median = timings.get_median()
538                 qworker_p95 = timings.get_percentile(95)
539                 ret += f' 💻p50: {qworker_median:.1f}s, 💻p95: {qworker_p95:.1f}s\n'
540             else:
541                 ret += '\n'
542             if count > 0:
543                 ret += f'    ...finished {count} total bundle(s) so far\n'
544             in_flight = len(self.in_flight_bundles_by_worker[worker])
545             if in_flight > 0:
546                 ret += f'    ...{in_flight} bundles currently in flight:\n'
547                 for bundle_uuid in self.in_flight_bundles_by_worker[worker]:
548                     details = self.bundle_details_by_uuid.get(bundle_uuid, None)
549                     pid = str(details.pid) if (details and details.pid != 0) else "TBD"
550                     if self.start_per_bundle[bundle_uuid] is not None:
551                         sec = ts - self.start_per_bundle[bundle_uuid]
552                         ret += f'       (pid={pid}): {details} for {sec:.1f}s so far '
553                     else:
554                         ret += f'       {details} setting up / copying data...'
555                         sec = 0.0
556
557                     if qworker_p95 is not None:
558                         if sec > qworker_p95:
559                             ret += f'{bg("red")}>💻p95{reset()} '
560                             if details is not None:
561                                 details.slower_than_local_p95 = True
562                         else:
563                             if details is not None:
564                                 details.slower_than_local_p95 = False
565
566                     if qall is not None:
567                         if sec > qall[1]:
568                             ret += f'{bg("red")}>∀p95{reset()} '
569                             if details is not None:
570                                 details.slower_than_global_p95 = True
571                         else:
572                             details.slower_than_global_p95 = False
573                     ret += '\n'
574         return ret
575
576     def periodic_dump(self, total_bundles_submitted: int) -> None:
577         assert self.lock.locked()
578         self.total_bundles_submitted = total_bundles_submitted
579         ts = time.time()
580         if self.last_periodic_dump is None or ts - self.last_periodic_dump > 5.0:
581             print(self)
582             self.last_periodic_dump = ts
583
584
585 class RemoteWorkerSelectionPolicy(ABC):
586     """A policy for selecting a remote worker base class."""
587
588     def __init__(self):
589         self.workers: Optional[List[RemoteWorkerRecord]] = None
590
591     def register_worker_pool(self, workers: List[RemoteWorkerRecord]):
592         self.workers = workers
593
594     @abstractmethod
595     def is_worker_available(self) -> bool:
596         pass
597
598     @abstractmethod
599     def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
600         pass
601
602
603 class WeightedRandomRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
604     """A remote worker selector that uses weighted RNG."""
605
606     @overrides
607     def is_worker_available(self) -> bool:
608         if self.workers:
609             for worker in self.workers:
610                 if worker.count > 0:
611                     return True
612         return False
613
614     @overrides
615     def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
616         grabbag = []
617         if self.workers:
618             for worker in self.workers:
619                 if worker.machine != machine_to_avoid:
620                     if worker.count > 0:
621                         for _ in range(worker.count * worker.weight):
622                             grabbag.append(worker)
623
624         if len(grabbag) == 0:
625             logger.debug(
626                 'There are no available workers that avoid %s', machine_to_avoid
627             )
628             if self.workers:
629                 for worker in self.workers:
630                     if worker.count > 0:
631                         for _ in range(worker.count * worker.weight):
632                             grabbag.append(worker)
633
634         if len(grabbag) == 0:
635             logger.warning('There are no available workers?!')
636             return None
637
638         worker = random.sample(grabbag, 1)[0]
639         assert worker.count > 0
640         worker.count -= 1
641         logger.debug('Selected worker %s', worker)
642         return worker
643
644
645 class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
646     """A remote worker selector that just round robins."""
647
648     def __init__(self) -> None:
649         super().__init__()
650         self.index = 0
651
652     @overrides
653     def is_worker_available(self) -> bool:
654         if self.workers:
655             for worker in self.workers:
656                 if worker.count > 0:
657                     return True
658         return False
659
660     @overrides
661     def acquire_worker(
662         self, machine_to_avoid: str = None
663     ) -> Optional[RemoteWorkerRecord]:
664         if self.workers:
665             x = self.index
666             while True:
667                 worker = self.workers[x]
668                 if worker.count > 0:
669                     worker.count -= 1
670                     x += 1
671                     if x >= len(self.workers):
672                         x = 0
673                     self.index = x
674                     logger.debug('Selected worker %s', worker)
675                     return worker
676                 x += 1
677                 if x >= len(self.workers):
678                     x = 0
679                 if x == self.index:
680                     logger.warning('Unexpectedly could not find a worker, retrying...')
681                     return None
682         return None
683
684
685 class RemoteExecutor(BaseExecutor):
686     """An executor that uses processes on remote machines to do work.  This
687     works by creating "bundles" of work with pickled code in each to be
688     executed.  Each bundle is assigned a remote worker based on some policy
689     heuristics.  Once assigned to a remote worker, a local subprocess is
690     created.  It copies the pickled code to the remote machine via ssh/scp
691     and then starts up work on the remote machine again using ssh.  When
692     the work is complete it copies the results back to the local machine.
693
694     So there is essentially one "controller" machine (which may also be
695     in the remote executor pool and therefore do task work in addition to
696     controlling) and N worker machines.  This code runs on the controller
697     whereas on the worker machines we invoke pickled user code via a
698     shim in :file:`remote_worker.py`.
699
700     Some redundancy and safety provisions are made; e.g. slower than
701     expected tasks have redundant backups created and if a task fails
702     repeatedly we consider it poisoned and give up on it.
703
704     .. warning::
705
706         The network overhead / latency of copying work from the
707         controller machine to the remote workers is relatively high.
708         This executor probably only makes sense to use with
709         computationally expensive tasks such as jobs that will execute
710         for ~30 seconds or longer.
711
712     See also :class:`ProcessExecutor` and :class:`ThreadExecutor`.
713     """
714
715     def __init__(
716         self,
717         workers: List[RemoteWorkerRecord],
718         policy: RemoteWorkerSelectionPolicy,
719     ) -> None:
720         """C'tor.
721
722         Args:
723             workers: A list of remote workers we can call on to do tasks.
724             policy: A policy for selecting remote workers for tasks.
725         """
726
727         super().__init__()
728         self.workers = workers
729         self.policy = policy
730         self.worker_count = 0
731         for worker in self.workers:
732             self.worker_count += worker.count
733         if self.worker_count <= 0:
734             msg = f"We need somewhere to schedule work; count was {self.worker_count}"
735             logger.critical(msg)
736             raise RemoteExecutorException(msg)
737         self.policy.register_worker_pool(self.workers)
738         self.cv = threading.Condition()
739         logger.debug(
740             'Creating %d local threads, one per remote worker.', self.worker_count
741         )
742         self._helper_executor = fut.ThreadPoolExecutor(
743             thread_name_prefix="remote_executor_helper",
744             max_workers=self.worker_count,
745         )
746         self.status = RemoteExecutorStatus(self.worker_count)
747         self.total_bundles_submitted = 0
748         self.backup_lock = threading.Lock()
749         self.last_backup = None
750         (
751             self.heartbeat_thread,
752             self.heartbeat_stop_event,
753         ) = self._run_periodic_heartbeat()
754         self.already_shutdown = False
755
756     @background_thread
757     def _run_periodic_heartbeat(self, stop_event: threading.Event) -> None:
758         """
759         We create a background thread to invoke :meth:`_heartbeat` regularly
760         while we are scheduling work.  It does some accounting such as
761         looking for slow bundles to tag for backup creation, checking for
762         unexpected failures, and printing a fancy message on stdout.
763         """
764         while not stop_event.is_set():
765             time.sleep(5.0)
766             logger.debug('Running periodic heartbeat code...')
767             self._heartbeat()
768         logger.debug('Periodic heartbeat thread shutting down.')
769
770     def _heartbeat(self) -> None:
771         # Note: this is invoked on a background thread, not an
772         # executor thread.  Be careful what you do with it b/c it
773         # needs to get back and dump status again periodically.
774         with self.status.lock:
775             self.status.periodic_dump(self.total_bundles_submitted)
776
777             # Look for bundles to reschedule via executor.submit
778             if config.config['executors_schedule_remote_backups']:
779                 self._maybe_schedule_backup_bundles()
780
781     def _maybe_schedule_backup_bundles(self):
782         """Maybe schedule backup bundles if we see a very slow bundle."""
783
784         assert self.status.lock.locked()
785         num_done = len(self.status.finished_bundle_timings)
786         num_idle_workers = self.worker_count - self.task_count
787         now = time.time()
788         if (
789             num_done >= 2
790             and num_idle_workers > 0
791             and (self.last_backup is None or (now - self.last_backup > 9.0))
792             and self.backup_lock.acquire(blocking=False)
793         ):
794             try:
795                 assert self.backup_lock.locked()
796
797                 bundle_to_backup = None
798                 best_score = None
799                 for (
800                     worker,
801                     bundle_uuids,
802                 ) in self.status.in_flight_bundles_by_worker.items():
803
804                     # Prefer to schedule backups of bundles running on
805                     # slower machines.
806                     base_score = 0
807                     for record in self.workers:
808                         if worker.machine == record.machine:
809                             base_score = float(record.weight)
810                             base_score = 1.0 / base_score
811                             base_score *= 200.0
812                             base_score = int(base_score)
813                             break
814
815                     for uuid in bundle_uuids:
816                         bundle = self.status.bundle_details_by_uuid.get(uuid, None)
817                         if (
818                             bundle is not None
819                             and bundle.src_bundle is None
820                             and bundle.backup_bundles is not None
821                         ):
822                             score = base_score
823
824                             # Schedule backups of bundles running
825                             # longer; especially those that are
826                             # unexpectedly slow.
827                             start_ts = self.status.start_per_bundle[uuid]
828                             if start_ts is not None:
829                                 runtime = now - start_ts
830                                 score += runtime
831                                 logger.debug(
832                                     'score[%s] => %.1f  # latency boost', bundle, score
833                                 )
834
835                                 if bundle.slower_than_local_p95:
836                                     score += runtime / 2
837                                     logger.debug(
838                                         'score[%s] => %.1f  # >worker p95',
839                                         bundle,
840                                         score,
841                                     )
842
843                                 if bundle.slower_than_global_p95:
844                                     score += runtime / 4
845                                     logger.debug(
846                                         'score[%s] => %.1f  # >global p95',
847                                         bundle,
848                                         score,
849                                     )
850
851                             # Prefer backups of bundles that don't
852                             # have backups already.
853                             backup_count = len(bundle.backup_bundles)
854                             if backup_count == 0:
855                                 score *= 2
856                             elif backup_count == 1:
857                                 score /= 2
858                             elif backup_count == 2:
859                                 score /= 8
860                             else:
861                                 score = 0
862                             logger.debug(
863                                 'score[%s] => %.1f  # {backup_count} dup backup factor',
864                                 bundle,
865                                 score,
866                             )
867
868                             if score != 0 and (
869                                 best_score is None or score > best_score
870                             ):
871                                 bundle_to_backup = bundle
872                                 assert bundle is not None
873                                 assert bundle.backup_bundles is not None
874                                 assert bundle.src_bundle is None
875                                 best_score = score
876
877                 # Note: this is all still happening on the heartbeat
878                 # runner thread.  That's ok because
879                 # _schedule_backup_for_bundle uses the executor to
880                 # submit the bundle again which will cause it to be
881                 # picked up by a worker thread and allow this thread
882                 # to return to run future heartbeats.
883                 if bundle_to_backup is not None:
884                     self.last_backup = now
885                     logger.info(
886                         '=====> SCHEDULING BACKUP %s (score=%.1f) <=====',
887                         bundle_to_backup,
888                         best_score,
889                     )
890                     self._schedule_backup_for_bundle(bundle_to_backup)
891             finally:
892                 self.backup_lock.release()
893
894     def _is_worker_available(self) -> bool:
895         """Is there a worker available currently?"""
896         return self.policy.is_worker_available()
897
898     def _acquire_worker(
899         self, machine_to_avoid: str = None
900     ) -> Optional[RemoteWorkerRecord]:
901         """Try to acquire a worker."""
902         return self.policy.acquire_worker(machine_to_avoid)
903
904     def _find_available_worker_or_block(
905         self, machine_to_avoid: str = None
906     ) -> RemoteWorkerRecord:
907         """Find a worker or block until one becomes available."""
908         with self.cv:
909             while not self._is_worker_available():
910                 self.cv.wait()
911             worker = self._acquire_worker(machine_to_avoid)
912             if worker is not None:
913                 return worker
914         msg = "We should never reach this point in the code"
915         logger.critical(msg)
916         raise Exception(msg)
917
918     def _release_worker(self, bundle: BundleDetails, *, was_cancelled=True) -> None:
919         """Release a previously acquired worker."""
920         worker = bundle.worker
921         assert worker is not None
922         logger.debug('Released worker %s', worker)
923         self.status.record_release_worker(
924             worker,
925             bundle.uuid,
926             was_cancelled,
927         )
928         with self.cv:
929             worker.count += 1
930             self.cv.notify()
931         self.adjust_task_count(-1)
932
933     def _check_if_cancelled(self, bundle: BundleDetails) -> bool:
934         """See if a particular bundle is cancelled.  Do not block."""
935         with self.status.lock:
936             if bundle.is_cancelled.wait(timeout=0.0):
937                 logger.debug('Bundle %s is cancelled, bail out.', bundle.uuid)
938                 bundle.was_cancelled = True
939                 return True
940         return False
941
942     def _launch(self, bundle: BundleDetails, override_avoid_machine=None) -> Any:
943         """Find a worker for bundle or block until one is available."""
944
945         self.adjust_task_count(+1)
946         uuid = bundle.uuid
947         hostname = bundle.hostname
948         avoid_machine = override_avoid_machine
949         is_original = bundle.src_bundle is None
950
951         # Try not to schedule a backup on the same host as the original.
952         if avoid_machine is None and bundle.src_bundle is not None:
953             avoid_machine = bundle.src_bundle.machine
954         worker = None
955         while worker is None:
956             worker = self._find_available_worker_or_block(avoid_machine)
957         assert worker is not None
958
959         # Ok, found a worker.
960         bundle.worker = worker
961         machine = bundle.machine = worker.machine
962         username = bundle.username = worker.username
963         self.status.record_acquire_worker(worker, uuid)
964         logger.debug('%s: Running bundle on %s...', bundle, worker)
965
966         # Before we do any work, make sure the bundle is still viable.
967         # It may have been some time between when it was submitted and
968         # now due to lack of worker availability and someone else may
969         # have already finished it.
970         if self._check_if_cancelled(bundle):
971             try:
972                 return self._process_work_result(bundle)
973             except Exception as e:
974                 logger.warning(
975                     '%s: bundle says it\'s cancelled upfront but no results?!', bundle
976                 )
977                 self._release_worker(bundle)
978                 if is_original:
979                     # Weird.  We are the original owner of this
980                     # bundle.  For it to have been cancelled, a backup
981                     # must have already started and completed before
982                     # we even for started.  Moreover, the backup says
983                     # it is done but we can't find the results it
984                     # should have copied over.  Reschedule the whole
985                     # thing.
986                     logger.exception(e)
987                     logger.error(
988                         '%s: We are the original owner thread and yet there are '
989                         'no results for this bundle.  This is unexpected and bad.',
990                         bundle,
991                     )
992                     return self._emergency_retry_nasty_bundle(bundle)
993                 else:
994                     # We're a backup and our bundle is cancelled
995                     # before we even got started.  Do nothing and let
996                     # the original bundle's thread worry about either
997                     # finding the results or complaining about it.
998                     return None
999
1000         # Send input code / data to worker machine if it's not local.
1001         if hostname not in machine:
1002             try:
1003                 cmd = (
1004                     f'{SCP} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
1005                 )
1006                 start_ts = time.time()
1007                 logger.info("%s: Copying work to %s via %s.", bundle, worker, cmd)
1008                 run_silently(cmd)
1009                 xfer_latency = time.time() - start_ts
1010                 logger.debug(
1011                     "%s: Copying to %s took %.1fs.", bundle, worker, xfer_latency
1012                 )
1013             except Exception as e:
1014                 self._release_worker(bundle)
1015                 if is_original:
1016                     # Weird.  We tried to copy the code to the worker
1017                     # and it failed...  And we're the original bundle.
1018                     # We have to retry.
1019                     logger.exception(e)
1020                     logger.error(
1021                         "%s: Failed to send instructions to the worker machine?! "
1022                         "This is not expected; we\'re the original bundle so this shouldn\'t "
1023                         "be a race condition.  Attempting an emergency retry...",
1024                         bundle,
1025                     )
1026                     return self._emergency_retry_nasty_bundle(bundle)
1027                 else:
1028                     # This is actually expected; we're a backup.
1029                     # There's a race condition where someone else
1030                     # already finished the work and removed the source
1031                     # code_file before we could copy it.  Ignore.
1032                     logger.warning(
1033                         '%s: Failed to send instructions to the worker machine... '
1034                         'We\'re a backup and this may be caused by the original (or '
1035                         'some other backup) already finishing this work.  Ignoring.',
1036                         bundle,
1037                     )
1038                     return None
1039
1040         # Kick off the work.  Note that if this fails we let
1041         # _wait_for_process deal with it.
1042         self.status.record_processing_began(uuid)
1043         cmd = (
1044             f'{SSH} {bundle.username}@{bundle.machine} '
1045             f'"source py39-venv/bin/activate &&'
1046             f' /home/scott/lib/python_modules/remote_worker.py'
1047             f' --code_file {bundle.code_file} --result_file {bundle.result_file}"'
1048         )
1049         logger.debug(
1050             '%s: Executing %s in the background to kick off work...', bundle, cmd
1051         )
1052         p = cmd_in_background(cmd, silent=True)
1053         bundle.pid = p.pid
1054         logger.debug(
1055             '%s: Local ssh process pid=%d; remote worker is %s.', bundle, p.pid, machine
1056         )
1057         return self._wait_for_process(p, bundle, 0)
1058
1059     def _wait_for_process(
1060         self, p: Optional[subprocess.Popen], bundle: BundleDetails, depth: int
1061     ) -> Any:
1062         """At this point we've copied the bundle's pickled code to the remote
1063         worker and started an ssh process that should be invoking the
1064         remote worker to have it execute the user's code.  See how
1065         that's going and wait for it to complete or fail.  Note that
1066         this code is recursive: there are codepaths where we decide to
1067         stop waiting for an ssh process (because another backup seems
1068         to have finished) but then fail to fetch or parse the results
1069         from that backup and thus call ourselves to continue waiting
1070         on an active ssh process.  This is the purpose of the depth
1071         argument: to curtail potential infinite recursion by giving up
1072         eventually.
1073
1074         Args:
1075             p: the Popen record of the ssh job
1076             bundle: the bundle of work being executed remotely
1077             depth: how many retries we've made so far.  Starts at zero.
1078
1079         """
1080
1081         machine = bundle.machine
1082         assert p is not None
1083         pid = p.pid  # pid of the ssh process
1084         if depth > 3:
1085             logger.error(
1086                 "I've gotten repeated errors waiting on this bundle; giving up on pid=%d",
1087                 pid,
1088             )
1089             p.terminate()
1090             self._release_worker(bundle)
1091             return self._emergency_retry_nasty_bundle(bundle)
1092
1093         # Spin until either the ssh job we scheduled finishes the
1094         # bundle or some backup worker signals that they finished it
1095         # before we could.
1096         while True:
1097             try:
1098                 p.wait(timeout=0.25)
1099             except subprocess.TimeoutExpired:
1100                 if self._check_if_cancelled(bundle):
1101                     logger.info(
1102                         '%s: looks like another worker finished bundle...', bundle
1103                     )
1104                     break
1105             else:
1106                 logger.info("%s: pid %d (%s) is finished!", bundle, pid, machine)
1107                 p = None
1108                 break
1109
1110         # If we get here we believe the bundle is done; either the ssh
1111         # subprocess finished (hopefully successfully) or we noticed
1112         # that some other worker seems to have completed the bundle
1113         # before us and we're bailing out.
1114         try:
1115             ret = self._process_work_result(bundle)
1116             if ret is not None and p is not None:
1117                 p.terminate()
1118             return ret
1119
1120         # Something went wrong; e.g. we could not copy the results
1121         # back, cleanup after ourselves on the remote machine, or
1122         # unpickle the results we got from the remove machine.  If we
1123         # still have an active ssh subprocess, keep waiting on it.
1124         # Otherwise, time for an emergency reschedule.
1125         except Exception as e:
1126             logger.exception(e)
1127             logger.error('%s: Something unexpected just happened...', bundle)
1128             if p is not None:
1129                 logger.warning(
1130                     "%s: Failed to wrap up \"done\" bundle, re-waiting on active ssh.",
1131                     bundle,
1132                 )
1133                 return self._wait_for_process(p, bundle, depth + 1)
1134             else:
1135                 self._release_worker(bundle)
1136                 return self._emergency_retry_nasty_bundle(bundle)
1137
1138     def _process_work_result(self, bundle: BundleDetails) -> Any:
1139         """A bundle seems to be completed.  Check on the results."""
1140
1141         with self.status.lock:
1142             is_original = bundle.src_bundle is None
1143             was_cancelled = bundle.was_cancelled
1144             username = bundle.username
1145             machine = bundle.machine
1146             result_file = bundle.result_file
1147             code_file = bundle.code_file
1148
1149             # Whether original or backup, if we finished first we must
1150             # fetch the results if the computation happened on a
1151             # remote machine.
1152             bundle.end_ts = time.time()
1153             if not was_cancelled:
1154                 assert bundle.machine is not None
1155                 if bundle.hostname not in bundle.machine:
1156                     cmd = f'{SCP} {username}@{machine}:{result_file} {result_file} 2>/dev/null'
1157                     logger.info(
1158                         "%s: Fetching results back from %s@%s via %s",
1159                         bundle,
1160                         username,
1161                         machine,
1162                         cmd,
1163                     )
1164
1165                     # If either of these throw they are handled in
1166                     # _wait_for_process.
1167                     attempts = 0
1168                     while True:
1169                         try:
1170                             run_silently(cmd)
1171                         except Exception as e:
1172                             attempts += 1
1173                             if attempts >= 3:
1174                                 raise e
1175                         else:
1176                             break
1177
1178                     # Cleanup remote /tmp files.
1179                     run_silently(
1180                         f'{SSH} {username}@{machine}'
1181                         f' "/bin/rm -f {code_file} {result_file}"'
1182                     )
1183                     logger.debug(
1184                         'Fetching results back took %.2fs', time.time() - bundle.end_ts
1185                     )
1186                 dur = bundle.end_ts - bundle.start_ts
1187                 self.histogram.add_item(dur)
1188
1189         # Only the original worker should unpickle the file contents
1190         # though since it's the only one whose result matters.  The
1191         # original is also the only job that may delete result_file
1192         # from disk.  Note that the original may have been cancelled
1193         # if one of the backups finished first; it still must read the
1194         # result from disk.  It still does that here with is_cancelled
1195         # set.
1196         if is_original:
1197             logger.debug("%s: Unpickling %s.", bundle, result_file)
1198             try:
1199                 with open(result_file, 'rb') as rb:
1200                     serialized = rb.read()
1201                 result = cloudpickle.loads(serialized)
1202             except Exception as e:
1203                 logger.exception(e)
1204                 logger.error('Failed to load %s... this is bad news.', result_file)
1205                 self._release_worker(bundle)
1206
1207                 # Re-raise the exception; the code in _wait_for_process may
1208                 # decide to _emergency_retry_nasty_bundle here.
1209                 raise e
1210             logger.debug('Removing local (master) %s and %s.', code_file, result_file)
1211             os.remove(result_file)
1212             os.remove(code_file)
1213
1214             # Notify any backups that the original is done so they
1215             # should stop ASAP.  Do this whether or not we
1216             # finished first since there could be more than one
1217             # backup.
1218             if bundle.backup_bundles is not None:
1219                 for backup in bundle.backup_bundles:
1220                     logger.debug(
1221                         '%s: Notifying backup %s that it\'s cancelled',
1222                         bundle,
1223                         backup.uuid,
1224                     )
1225                     backup.is_cancelled.set()
1226
1227         # This is a backup job and, by now, we have already fetched
1228         # the bundle results.
1229         else:
1230             # Backup results don't matter, they just need to leave the
1231             # result file in the right place for their originals to
1232             # read/unpickle later.
1233             result = None
1234
1235             # Tell the original to stop if we finished first.
1236             if not was_cancelled:
1237                 orig_bundle = bundle.src_bundle
1238                 assert orig_bundle is not None
1239                 logger.debug(
1240                     '%s: Notifying original %s we beat them to it.',
1241                     bundle,
1242                     orig_bundle.uuid,
1243                 )
1244                 orig_bundle.is_cancelled.set()
1245         self._release_worker(bundle, was_cancelled=was_cancelled)
1246         return result
1247
1248     def _create_original_bundle(self, pickle, function_name: str):
1249         """Creates a bundle that is not a backup of any other bundle but
1250         rather represents a user task.
1251         """
1252
1253         uuid = string_utils.generate_uuid(omit_dashes=True)
1254         code_file = f'/tmp/{uuid}.code.bin'
1255         result_file = f'/tmp/{uuid}.result.bin'
1256
1257         logger.debug('Writing pickled code to %s', code_file)
1258         with open(code_file, 'wb') as wb:
1259             wb.write(pickle)
1260
1261         bundle = BundleDetails(
1262             pickled_code=pickle,
1263             uuid=uuid,
1264             function_name=function_name,
1265             worker=None,
1266             username=None,
1267             machine=None,
1268             hostname=platform.node(),
1269             code_file=code_file,
1270             result_file=result_file,
1271             pid=0,
1272             start_ts=time.time(),
1273             end_ts=0.0,
1274             slower_than_local_p95=False,
1275             slower_than_global_p95=False,
1276             src_bundle=None,
1277             is_cancelled=threading.Event(),
1278             was_cancelled=False,
1279             backup_bundles=[],
1280             failure_count=0,
1281         )
1282         self.status.record_bundle_details(bundle)
1283         logger.debug('%s: Created an original bundle', bundle)
1284         return bundle
1285
1286     def _create_backup_bundle(self, src_bundle: BundleDetails):
1287         """Creates a bundle that is a backup of another bundle that is
1288         running too slowly."""
1289
1290         assert self.status.lock.locked()
1291         assert src_bundle.backup_bundles is not None
1292         n = len(src_bundle.backup_bundles)
1293         uuid = src_bundle.uuid + f'_backup#{n}'
1294
1295         backup_bundle = BundleDetails(
1296             pickled_code=src_bundle.pickled_code,
1297             uuid=uuid,
1298             function_name=src_bundle.function_name,
1299             worker=None,
1300             username=None,
1301             machine=None,
1302             hostname=src_bundle.hostname,
1303             code_file=src_bundle.code_file,
1304             result_file=src_bundle.result_file,
1305             pid=0,
1306             start_ts=time.time(),
1307             end_ts=0.0,
1308             slower_than_local_p95=False,
1309             slower_than_global_p95=False,
1310             src_bundle=src_bundle,
1311             is_cancelled=threading.Event(),
1312             was_cancelled=False,
1313             backup_bundles=None,  # backup backups not allowed
1314             failure_count=0,
1315         )
1316         src_bundle.backup_bundles.append(backup_bundle)
1317         self.status.record_bundle_details_already_locked(backup_bundle)
1318         logger.debug('%s: Created a backup bundle', backup_bundle)
1319         return backup_bundle
1320
1321     def _schedule_backup_for_bundle(self, src_bundle: BundleDetails):
1322         """Schedule a backup of src_bundle."""
1323
1324         assert self.status.lock.locked()
1325         assert src_bundle is not None
1326         backup_bundle = self._create_backup_bundle(src_bundle)
1327         logger.debug(
1328             '%s/%s: Scheduling backup for execution...',
1329             backup_bundle.uuid,
1330             backup_bundle.function_name,
1331         )
1332         self._helper_executor.submit(self._launch, backup_bundle)
1333
1334         # Results from backups don't matter; if they finish first
1335         # they will move the result_file to this machine and let
1336         # the original pick them up and unpickle them (and return
1337         # a result).
1338
1339     def _emergency_retry_nasty_bundle(
1340         self, bundle: BundleDetails
1341     ) -> Optional[fut.Future]:
1342         """Something unexpectedly failed with bundle.  Either retry it
1343         from the beginning or throw in the towel and give up on it."""
1344
1345         is_original = bundle.src_bundle is None
1346         bundle.worker = None
1347         avoid_last_machine = bundle.machine
1348         bundle.machine = None
1349         bundle.username = None
1350         bundle.failure_count += 1
1351         if is_original:
1352             retry_limit = 3
1353         else:
1354             retry_limit = 2
1355
1356         if bundle.failure_count > retry_limit:
1357             logger.error(
1358                 '%s: Tried this bundle too many times already (%dx); giving up.',
1359                 bundle,
1360                 retry_limit,
1361             )
1362             if is_original:
1363                 raise RemoteExecutorException(
1364                     f'{bundle}: This bundle can\'t be completed despite several backups and retries',
1365                 )
1366             else:
1367                 logger.error(
1368                     '%s: At least it\'s only a backup; better luck with the others.',
1369                     bundle,
1370                 )
1371             return None
1372         else:
1373             msg = f'>>> Emergency rescheduling {bundle} because of unexected errors (wtf?!) <<<'
1374             logger.warning(msg)
1375             warnings.warn(msg)
1376             return self._launch(bundle, avoid_last_machine)
1377
1378     @overrides
1379     def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
1380         """Submit work to be done.  This is the user entry point of this
1381         class."""
1382         if self.already_shutdown:
1383             raise Exception('Submitted work after shutdown.')
1384         pickle = _make_cloud_pickle(function, *args, **kwargs)
1385         bundle = self._create_original_bundle(pickle, function.__name__)
1386         self.total_bundles_submitted += 1
1387         return self._helper_executor.submit(self._launch, bundle)
1388
1389     @overrides
1390     def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
1391         """Shutdown the executor."""
1392         if not self.already_shutdown:
1393             logging.debug('Shutting down RemoteExecutor %s', self.title)
1394             self.heartbeat_stop_event.set()
1395             self.heartbeat_thread.join()
1396             self._helper_executor.shutdown(wait)
1397             if not quiet:
1398                 print(self.histogram.__repr__(label_formatter='%ds'))
1399             self.already_shutdown = True
1400
1401
1402 class RemoteWorkerPoolProvider:
1403     @abstractmethod
1404     def get_remote_workers(self) -> List[RemoteWorkerRecord]:
1405         pass
1406
1407
1408 @persistent.persistent_autoloaded_singleton()  # type: ignore
1409 class ConfigRemoteWorkerPoolProvider(
1410     RemoteWorkerPoolProvider, persistent.JsonFileBasedPersistent
1411 ):
1412     def __init__(self, json_remote_worker_pool: Dict[str, Any]):
1413         self.remote_worker_pool = []
1414         for record in json_remote_worker_pool['remote_worker_records']:
1415             self.remote_worker_pool.append(
1416                 self.dataclassFromDict(RemoteWorkerRecord, record)
1417             )
1418         assert len(self.remote_worker_pool) > 0
1419
1420     @staticmethod
1421     def dataclassFromDict(clsName, argDict: Dict[str, Any]) -> Any:
1422         fieldSet = {f.name for f in fields(clsName) if f.init}
1423         filteredArgDict = {k: v for k, v in argDict.items() if k in fieldSet}
1424         return clsName(**filteredArgDict)
1425
1426     @overrides
1427     def get_remote_workers(self) -> List[RemoteWorkerRecord]:
1428         return self.remote_worker_pool
1429
1430     @overrides
1431     def get_persistent_data(self) -> List[RemoteWorkerRecord]:
1432         return self.remote_worker_pool
1433
1434     @staticmethod
1435     @overrides
1436     def get_filename() -> str:
1437         return config.config['remote_worker_records_file']
1438
1439     @staticmethod
1440     @overrides
1441     def should_we_load_data(filename: str) -> bool:
1442         return True
1443
1444     @staticmethod
1445     @overrides
1446     def should_we_save_data(filename: str) -> bool:
1447         return False
1448
1449
1450 @singleton
1451 class DefaultExecutors(object):
1452     """A container for a default thread, process and remote executor.
1453     These are not created until needed and we take care to clean up
1454     before process exit automatically for the caller's convenience.
1455     Instead of creating your own executor, consider using the one
1456     from this pool.  e.g.::
1457
1458         @par.parallelize(method=par.Method.PROCESS)
1459         def do_work(
1460             solutions: List[Work],
1461             shard_num: int,
1462             ...
1463         ):
1464             <do the work>
1465
1466
1467         def start_do_work(all_work: List[Work]):
1468             shards = []
1469             logger.debug('Sharding work into groups of 10.')
1470             for subset in list_utils.shard(all_work, 10):
1471                 shards.append([x for x in subset])
1472
1473             logger.debug('Kicking off helper pool.')
1474             try:
1475                 for n, shard in enumerate(shards):
1476                     results.append(
1477                         do_work(
1478                             shard, n, shared_cache.get_name(), max_letter_pop_per_word
1479                         )
1480                     )
1481                 smart_future.wait_all(results)
1482             finally:
1483                 # Note: if you forget to do this it will clean itself up
1484                 # during program termination including tearing down any
1485                 # active ssh connections.
1486                 executors.DefaultExecutors().process_pool().shutdown()
1487     """
1488
1489     def __init__(self):
1490         self.thread_executor: Optional[ThreadExecutor] = None
1491         self.process_executor: Optional[ProcessExecutor] = None
1492         self.remote_executor: Optional[RemoteExecutor] = None
1493
1494     @staticmethod
1495     def _ping(host) -> bool:
1496         logger.debug('RUN> ping -c 1 %s', host)
1497         try:
1498             x = cmd_exitcode(
1499                 f'ping -c 1 {host} >/dev/null 2>/dev/null', timeout_seconds=1.0
1500             )
1501             return x == 0
1502         except Exception:
1503             return False
1504
1505     def thread_pool(self) -> ThreadExecutor:
1506         if self.thread_executor is None:
1507             self.thread_executor = ThreadExecutor()
1508         return self.thread_executor
1509
1510     def process_pool(self) -> ProcessExecutor:
1511         if self.process_executor is None:
1512             self.process_executor = ProcessExecutor()
1513         return self.process_executor
1514
1515     def remote_pool(self) -> RemoteExecutor:
1516         if self.remote_executor is None:
1517             logger.info('Looking for some helper machines...')
1518             provider = ConfigRemoteWorkerPoolProvider()
1519             all_machines = provider.get_remote_workers()
1520             pool = []
1521
1522             # Make sure we can ping each machine.
1523             for record in all_machines:
1524                 if self._ping(record.machine):
1525                     logger.info('%s is alive / responding to pings', record.machine)
1526                     pool.append(record)
1527
1528             # The controller machine has a lot to do; go easy on it.
1529             for record in pool:
1530                 if record.machine == platform.node() and record.count > 1:
1531                     logger.info('Reducing workload for %s.', record.machine)
1532                     record.count = max(int(record.count / 2), 1)
1533
1534             policy = WeightedRandomRemoteWorkerSelectionPolicy()
1535             policy.register_worker_pool(pool)
1536             self.remote_executor = RemoteExecutor(pool, policy)
1537         return self.remote_executor
1538
1539     def shutdown(self) -> None:
1540         if self.thread_executor is not None:
1541             self.thread_executor.shutdown(wait=True, quiet=True)
1542             self.thread_executor = None
1543         if self.process_executor is not None:
1544             self.process_executor.shutdown(wait=True, quiet=True)
1545             self.process_executor = None
1546         if self.remote_executor is not None:
1547             self.remote_executor.shutdown(wait=True, quiet=True)
1548             self.remote_executor = None