from __future__ import annotations

from abc import ABC, abstractmethod
import concurrent.futures as fut
from collections import defaultdict
from dataclasses import dataclass
import logging
import os
import platform
import random
import subprocess
import threading
import time
from typing import Any, Callable, Dict, List, Optional, Set

import cloudpickle  # type: ignore
import numpy
from overrides import overrides

from ansi import bg, fg, underline, reset
import argparse_utils
import config
from exec_utils import run_silently, cmd_in_background
from decorator_utils import singleton
import histogram as hist

logger = logging.getLogger(__name__)

parser = config.add_commandline_args(
    f"Executors ({__file__})",
    "Args related to processing executors."
)
# NB: the argument types and defaults below were elided from this
# excerpt and are assumed; the flags and help strings are original.
parser.add_argument(
    '--executors_threadpool_size',
    type=int,
    default=None,
    help='Number of threads in the default threadpool, leave unset for default',
)
parser.add_argument(
    '--executors_processpool_size',
    type=int,
    default=None,
    help='Number of processes in the default processpool, leave unset for default',
)
parser.add_argument(
    '--executors_schedule_remote_backups',
    default=True,
    action=argparse_utils.ActionNoYes,
    help='Should we schedule duplicative backup work if a remote bundle is slow',
)
parser.add_argument(
    '--executors_max_bundle_failures',
    type=int,
    default=3,
    help='Maximum number of failures before giving up on a bundle',
)

# rsync flags: quiet, no motd, whole-file transfers, never clobber an
# existing destination file, 60s I/O timeout, size-only change
# detection, compression on the wire.
RSYNC = 'rsync -q --no-motd -W --ignore-existing --timeout=60 --size-only -z'
SSH = 'ssh -oForwardX11=no'


def make_cloud_pickle(fun, *args, **kwargs):
    logger.debug(f"Making cloudpickled bundle at {fun.__name__}")
    return cloudpickle.dumps((fun, args, kwargs))

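# Round-trip sketch (hypothetical; any cloudpickle-able callable works):
#
#   blob = make_cloud_pickle(max, 1, 2)
#   fun, args, kwargs = cloudpickle.loads(blob)
#   assert fun(*args, **kwargs) == 2
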

class BaseExecutor(ABC):
    def __init__(self, *, title=''):
        self.title = title
        self.task_count = 0
        self.histogram = hist.SimpleHistogram(
            hist.SimpleHistogram.n_evenly_spaced_buckets(
                0, 500, 50  # bucket parameters elided in excerpt; assumed
            )
        )

    @abstractmethod
    def submit(self,
               function: Callable,
               *args,
               **kwargs) -> fut.Future:
        pass

    @abstractmethod
    def shutdown(self,
                 wait: bool = True) -> None:
        pass

    def adjust_task_count(self, delta: int) -> None:
        self.task_count += delta
        logger.debug(f'Executor current task count is {self.task_count}')


class ThreadExecutor(BaseExecutor):
    def __init__(self,
                 max_workers: Optional[int] = None):
        super().__init__()
        workers = None
        if max_workers is not None:
            workers = max_workers
        elif 'executors_threadpool_size' in config.config:
            workers = config.config['executors_threadpool_size']
        logger.debug(f'Creating threadpool executor with {workers} workers')
        self._thread_pool_executor = fut.ThreadPoolExecutor(
            max_workers=workers,
            thread_name_prefix="thread_executor_helper"
        )

    def run_local_bundle(self, fun, *args, **kwargs):
        logger.debug(f"Running local bundle at {fun.__name__}")
        start = time.time()
        result = fun(*args, **kwargs)
        end = time.time()
        self.adjust_task_count(-1)
        duration = end - start
        logger.debug(f"{fun.__name__} finished; used {duration:.1f}s")
        self.histogram.add_item(duration)
        return result

    @overrides
    def submit(self,
               function: Callable,
               *args,
               **kwargs) -> fut.Future:
        self.adjust_task_count(+1)
        newargs = []
        newargs.append(function)
        for arg in args:
            newargs.append(arg)
        return self._thread_pool_executor.submit(
            self.run_local_bundle,
            *newargs,
            **kwargs)

    @overrides
    def shutdown(self,
                 wait: bool = True) -> None:
        logger.debug(f'Shutting down threadpool executor {self.title}')
        print(self.histogram)
        self._thread_pool_executor.shutdown(wait)

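# Usage sketch (hypothetical): work runs on threads inside this process,
# so arguments need not be picklable.
#
#   executor = ThreadExecutor()
#   future = executor.submit(sum, [1, 2, 3])
#   assert future.result() == 6
#   executor.shutdown()
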

class ProcessExecutor(BaseExecutor):
    def __init__(self,
                 max_workers: Optional[int] = None):
        super().__init__()
        workers = None
        if max_workers is not None:
            workers = max_workers
        elif 'executors_processpool_size' in config.config:
            workers = config.config['executors_processpool_size']
        logger.debug(f'Creating processpool executor with {workers} workers.')
        self._process_executor = fut.ProcessPoolExecutor(
            max_workers=workers,
        )

    def run_cloud_pickle(self, pickle):
        fun, args, kwargs = cloudpickle.loads(pickle)
        logger.debug(f"Running pickled bundle at {fun.__name__}")
        result = fun(*args, **kwargs)
        self.adjust_task_count(-1)
        return result

    @overrides
    def submit(self,
               function: Callable,
               *args,
               **kwargs) -> fut.Future:
        start = time.time()
        self.adjust_task_count(+1)
        pickle = make_cloud_pickle(function, *args, **kwargs)
        result = self._process_executor.submit(
            self.run_cloud_pickle,
            pickle
        )
        result.add_done_callback(
            lambda _: self.histogram.add_item(
                time.time() - start
            )
        )
        return result

    @overrides
    def shutdown(self, wait=True) -> None:
        logger.debug(f'Shutting down processpool executor {self.title}')
        self._process_executor.shutdown(wait)
        print(self.histogram)

    def __getstate__(self):
        state = self.__dict__.copy()
        state['_process_executor'] = None
        return state

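    # NB: run_cloud_pickle is a bound method, so submitting it to the
    # process pool pickles self. __getstate__ above therefore drops the
    # (unpicklable) ProcessPoolExecutor handle; the copy that wakes up
    # in the child process sees _process_executor = None.
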

@dataclass
class RemoteWorkerRecord:
    username: str
    machine: str
    weight: int
    count: int

    def __hash__(self):
        return hash((self.username, self.machine))

    def __repr__(self):
        return f'{self.username}@{self.machine}'


@dataclass
class BundleDetails:
    pickled_code: bytes
    uuid: str
    fname: str
    worker: Optional[RemoteWorkerRecord]
    username: Optional[str]
    machine: Optional[str]
    hostname: str
    code_file: str
    result_file: str
    pid: int
    start_ts: float
    end_ts: float
    too_slow: bool
    super_slow: bool
    src_bundle: Optional[BundleDetails]
    is_cancelled: threading.Event
    was_cancelled: bool
    backup_bundles: Optional[List[BundleDetails]]
    failure_count: int

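# Lifecycle notes: an "original" bundle has src_bundle=None and an empty
# backup_bundles list; each backup points back at its original via
# src_bundle and is never given backups of its own. Whichever of the two
# finishes first sets the other's is_cancelled event.
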

class RemoteExecutorStatus:
    def __init__(self, total_worker_count: int) -> None:
        self.worker_count = total_worker_count
        self.known_workers: Set[RemoteWorkerRecord] = set()
        self.start_per_bundle: Dict[str, float] = defaultdict(float)
        self.end_per_bundle: Dict[str, float] = defaultdict(float)
        self.finished_bundle_timings_per_worker: Dict[
            RemoteWorkerRecord,
            List[float]
        ] = {}
        self.in_flight_bundles_by_worker: Dict[
            RemoteWorkerRecord,
            Set[str]
        ] = {}
        self.bundle_details_by_uuid: Dict[str, BundleDetails] = {}
        self.finished_bundle_timings: List[float] = []
        self.last_periodic_dump: Optional[float] = None
        self.total_bundles_submitted = 0

        # Protects reads and modifications of self. Also used as a
        # memory fence for modifications to bundle state.
        self.lock = threading.Lock()

    def record_acquire_worker(
            self,
            worker: RemoteWorkerRecord,
            uuid: str) -> None:
        with self.lock:
            self.record_acquire_worker_already_locked(
                worker,
                uuid
            )

    def record_acquire_worker_already_locked(
            self,
            worker: RemoteWorkerRecord,
            uuid: str) -> None:
        assert self.lock.locked()
        self.known_workers.add(worker)
        self.start_per_bundle[uuid] = time.time()
        x = self.in_flight_bundles_by_worker.get(worker, set())
        x.add(uuid)
        self.in_flight_bundles_by_worker[worker] = x

    def record_bundle_details(
            self,
            details: BundleDetails) -> None:
        with self.lock:
            self.record_bundle_details_already_locked(details)

    def record_bundle_details_already_locked(
            self,
            details: BundleDetails) -> None:
        assert self.lock.locked()
        self.bundle_details_by_uuid[details.uuid] = details

    def record_release_worker_already_locked(
            self,
            worker: RemoteWorkerRecord,
            uuid: str,
            was_cancelled: bool) -> None:
        assert self.lock.locked()
        ts = time.time()
        self.end_per_bundle[uuid] = ts
        self.in_flight_bundles_by_worker[worker].remove(uuid)
        if not was_cancelled:
            bundle_latency = ts - self.start_per_bundle[uuid]
            x = self.finished_bundle_timings_per_worker.get(worker, list())
            x.append(bundle_latency)
            self.finished_bundle_timings_per_worker[worker] = x
            self.finished_bundle_timings.append(bundle_latency)

    def total_in_flight(self) -> int:
        assert self.lock.locked()
        total_in_flight = 0
        for worker in self.known_workers:
            total_in_flight += len(self.in_flight_bundles_by_worker[worker])
        return total_in_flight

    def total_idle(self) -> int:
        assert self.lock.locked()
        return self.worker_count - self.total_in_flight()

    def __repr__(self):
        assert self.lock.locked()
        ts = time.time()
        total_finished = len(self.finished_bundle_timings)
        total_in_flight = self.total_in_flight()
        ret = f'\n\n{underline()}Remote Executor Pool Status{reset()}: '
        qall = None
        if len(self.finished_bundle_timings) > 1:
            qall = numpy.quantile(self.finished_bundle_timings, [0.5, 0.95])
            ret += (
                f'⏱=∀p50:{qall[0]:.1f}s, ∀p95:{qall[1]:.1f}s, '
                f'✅={total_finished}/{self.total_bundles_submitted}, '
                f'💻n={total_in_flight}/{self.worker_count}\n'
            )
        else:
            ret += (
                f' ✅={total_finished}/{self.total_bundles_submitted}, '
                f'💻n={total_in_flight}/{self.worker_count}\n'
            )
        for worker in self.known_workers:
            ret += f' {fg("lightning yellow")}{worker.machine}{reset()}: '
            timings = self.finished_bundle_timings_per_worker.get(worker, [])
            count = len(timings)
            qworker = None
            if count > 1:
                qworker = numpy.quantile(timings, [0.5, 0.95])
                ret += f' 💻p50: {qworker[0]:.1f}s, 💻p95: {qworker[1]:.1f}s\n'
            else:
                ret += '\n'
            if count > 0:
                ret += f' ...finished {count} total bundle(s) so far\n'
            in_flight = len(self.in_flight_bundles_by_worker[worker])
            if in_flight > 0:
                ret += f' ...{in_flight} bundles currently in flight:\n'
                for bundle_uuid in self.in_flight_bundles_by_worker[worker]:
                    details = self.bundle_details_by_uuid.get(
                        bundle_uuid,
                        None
                    )
                    pid = str(details.pid) if details is not None else "TBD"
                    sec = ts - self.start_per_bundle[bundle_uuid]
                    ret += f' (pid={pid}): {bundle_uuid} for {sec:.1f}s so far '
                    if qworker is not None:
                        if sec > qworker[1]:
                            ret += f'{bg("red")}>💻p95{reset()} '
                        elif sec > qworker[0]:
                            ret += f'{fg("red")}>💻p50{reset()} '
                    if qall is not None:
                        if sec > qall[1] * 1.5:
                            ret += f'{bg("red")}!!!{reset()}'
                            if details is not None:
                                logger.debug(f'Flagging {details.uuid} for another backup')
                                details.super_slow = True
                        elif sec > qall[1]:
                            ret += f'{bg("red")}>∀p95{reset()} '
                            if details is not None:
                                logger.debug(f'Flagging {details.uuid} for a backup')
                                details.too_slow = True
                        elif sec > qall[0]:
                            ret += f'{fg("red")}>∀p50{reset()}'
                    ret += '\n'
        return ret

    def periodic_dump(self, total_bundles_submitted: int) -> None:
        assert self.lock.locked()
        self.total_bundles_submitted = total_bundles_submitted
        ts = time.time()
        if (
                self.last_periodic_dump is None
                or ts - self.last_periodic_dump > 5.0
        ):
            print(self)
            self.last_periodic_dump = ts


class RemoteWorkerSelectionPolicy(ABC):
    def register_worker_pool(self, workers):
        self.workers = workers

    @abstractmethod
    def is_worker_available(self) -> bool:
        pass

    @abstractmethod
    def acquire_worker(
            self,
            machine_to_avoid: Optional[str] = None
    ) -> Optional[RemoteWorkerRecord]:
        pass


class WeightedRandomRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
    def is_worker_available(self) -> bool:
        for worker in self.workers:
            if worker.count > 0:
                return True
        return False

    def acquire_worker(
            self,
            machine_to_avoid: Optional[str] = None
    ) -> Optional[RemoteWorkerRecord]:
        grabbag = []
        for worker in self.workers:
            for x in range(0, worker.count):
                for y in range(0, worker.weight):
                    grabbag.append(worker)
        for _ in range(0, 5):
            random.shuffle(grabbag)
            worker = grabbag[0]
            # After a few attempts, accept a worker on machine_to_avoid
            # rather than failing outright.
            if worker.machine != machine_to_avoid or _ > 2:
                worker.count -= 1
                logger.debug(f'Selected worker {worker}')
                return worker
        logger.warning("Couldn't find a worker; go fish.")
        return None

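# Worked example of the weighting (hypothetical numbers): a worker with
# count=2 and weight=3 contributes 2 * 3 = 6 entries to the grabbag, so
# a draw is six times more likely to pick it than a worker contributing
# a single entry.
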

class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
    def __init__(self) -> None:
        self.index = 0

    def is_worker_available(self) -> bool:
        for worker in self.workers:
            if worker.count > 0:
                return True
        return False

    def acquire_worker(
            self,
            machine_to_avoid: Optional[str] = None
    ) -> Optional[RemoteWorkerRecord]:
        # Note: this policy ignores machine_to_avoid.
        x = self.index
        while True:
            worker = self.workers[x]
            if worker.count > 0:
                worker.count -= 1
                x += 1
                if x >= len(self.workers):
                    x = 0
                self.index = x
                logger.debug(f'Selected worker {worker}')
                return worker
            x += 1
            if x >= len(self.workers):
                x = 0
            if x == self.index:
                logger.warning("Couldn't find a worker; go fish.")
                return None


class RemoteExecutor(BaseExecutor):
    def __init__(self,
                 workers: List[RemoteWorkerRecord],
                 policy: RemoteWorkerSelectionPolicy) -> None:
        super().__init__()
        self.workers = workers
        self.policy = policy
        self.worker_count = 0
        for worker in self.workers:
            self.worker_count += worker.count
        if self.worker_count <= 0:
            msg = f"We need somewhere to schedule work; count was {self.worker_count}"
            logger.critical(msg)
            raise Exception(msg)
        self.policy.register_worker_pool(self.workers)
        self.cv = threading.Condition()
        self._helper_executor = fut.ThreadPoolExecutor(
            thread_name_prefix="remote_executor_helper",
            max_workers=self.worker_count,
        )
        self.status = RemoteExecutorStatus(self.worker_count)
        self.total_bundles_submitted = 0
        logger.debug(
            f'Creating remote processpool with {self.worker_count} remote worker threads.'
        )

    def is_worker_available(self) -> bool:
        return self.policy.is_worker_available()

    def acquire_worker(
            self,
            machine_to_avoid: Optional[str] = None
    ) -> Optional[RemoteWorkerRecord]:
        return self.policy.acquire_worker(machine_to_avoid)

    def find_available_worker_or_block(
            self,
            machine_to_avoid: Optional[str] = None
    ) -> RemoteWorkerRecord:
        with self.cv:
            while not self.is_worker_available():
                self.cv.wait()
            worker = self.acquire_worker(machine_to_avoid)
            if worker is not None:
                return worker
        msg = "We should never reach this point in the code"
        logger.critical(msg)
        raise Exception(msg)

    def release_worker(self, worker: RemoteWorkerRecord) -> None:
        logger.debug(f'Released worker {worker}')
        with self.cv:
            worker.count += 1
            self.cv.notify()

    def heartbeat(self) -> None:
        with self.status.lock:
            # Regular progress report
            self.status.periodic_dump(self.total_bundles_submitted)

            # Look for bundles to reschedule. Wait until a handful of
            # timings have accumulated so the quantile-based slowness
            # flags mean something.
            if len(self.status.finished_bundle_timings) > 7:
                for worker, bundle_uuids in self.status.in_flight_bundles_by_worker.items():
                    for uuid in bundle_uuids:
                        bundle = self.status.bundle_details_by_uuid.get(uuid, None)
                        if (
                                bundle is not None and
                                bundle.too_slow and
                                bundle.src_bundle is None and
                                config.config['executors_schedule_remote_backups']
                        ):
                            self.consider_backup_for_bundle(bundle)

    def consider_backup_for_bundle(self, bundle: BundleDetails) -> None:
        assert self.status.lock.locked()
        if (
                bundle.too_slow
                and len(bundle.backup_bundles) == 0    # one backup per bundle
        ):
            msg = f"*** Rescheduling {bundle.pid}/{bundle.uuid} ***"
            logger.debug(msg)
            self.schedule_backup_for_bundle(bundle)
            return
        elif (
                bundle.super_slow
                and len(bundle.backup_bundles) < 2     # two backups in dire situations
                and self.status.total_idle() > 4
        ):
            msg = f"*** Rescheduling {bundle.pid}/{bundle.uuid} ***"
            logger.debug(msg)
            self.schedule_backup_for_bundle(bundle)
            return

    def check_if_cancelled(self, bundle: BundleDetails) -> bool:
        with self.status.lock:
            if bundle.is_cancelled.wait(timeout=0.0):
                logger.debug(f'Bundle {bundle.uuid} is cancelled, bail out.')
                bundle.was_cancelled = True
                return True
        return False

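    # Cancellation protocol: an original and its backups each carry a
    # threading.Event; whichever bundle finishes first sets the other
    # side's event, and the wait(timeout=0.0) above is simply a
    # non-blocking test of that flag.
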
    def launch(self, bundle: BundleDetails) -> Any:
        """Find a worker for bundle or block until one is available."""
        self.adjust_task_count(+1)
        uuid = bundle.uuid
        hostname = bundle.hostname
        avoid_machine = None

        # Try not to schedule a backup on the same host as the original.
        if bundle.src_bundle is not None:
            avoid_machine = bundle.src_bundle.machine
        worker = None
        while worker is None:
            worker = self.find_available_worker_or_block(avoid_machine)
        bundle.worker = worker
        machine = bundle.machine = worker.machine
        username = bundle.username = worker.username
        fname = bundle.fname
        self.status.record_acquire_worker(worker, uuid)
        logger.debug(f'Running bundle {uuid} on {worker}...')

        # Before we do any work, make sure the bundle is still viable.
        if self.check_if_cancelled(bundle):
            try:
                return self.post_launch_work(bundle)
            except Exception as e:
                logger.exception(e)
                logger.info(f"{uuid}/{fname}: bundle seems to have failed?!")
                if bundle.failure_count < config.config['executors_max_bundle_failures']:
                    return self.launch(bundle)
                else:
                    logger.info(f"{uuid}/{fname}: bundle is poison, giving up on it.")
                    return None

        # Send input to machine if it's not local.
        if hostname not in machine:
            cmd = f'{RSYNC} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
            logger.info(f"{uuid}/{fname}: Copying work to {worker} via {cmd}")
            run_silently(cmd)

        # Kick off the work on the remote worker.
        cmd = (f'{SSH} {bundle.username}@{bundle.machine} '
               f'"source py39-venv/bin/activate &&'
               f' /home/scott/lib/python_modules/remote_worker.py'
               f' --code_file {bundle.code_file} --result_file {bundle.result_file}"')
        p = cmd_in_background(cmd, silent=True)
        bundle.pid = pid = p.pid
        logger.info(f"{uuid}/{fname}: Starting work on {worker} via {cmd} (background pid {pid})")

        # Poll the background process; use the waits as a heartbeat.
        while True:
            try:
                p.wait(timeout=0.25)
            except subprocess.TimeoutExpired:
                self.heartbeat()

                # Both source and backup bundles can be cancelled by
                # the other depending on which finishes first.
                if self.check_if_cancelled(bundle):
                    p.terminate()
                    break
            else:
                logger.debug(
                    f"{uuid}/{fname}: pid {pid} has finished its work normally."
                )
                break

        try:
            return self.post_launch_work(bundle)
        except Exception as e:
            logger.exception(e)
            logger.info(f"{uuid}: Bundle seems to have failed?!")
            if bundle.failure_count < config.config['executors_max_bundle_failures']:
                return self.launch(bundle)
            logger.info(f"{uuid}: Bundle is poison, giving up on it.")
            return None

    def post_launch_work(self, bundle: BundleDetails) -> Any:
        with self.status.lock:
            is_original = bundle.src_bundle is None
            was_cancelled = bundle.was_cancelled
            username = bundle.username
            machine = bundle.machine
            result_file = bundle.result_file
            code_file = bundle.code_file
            fname = bundle.fname
            uuid = bundle.uuid

            # Whether original or backup, if we finished first we must
            # fetch the results if the computation happened on a
            # remote machine.
            bundle.end_ts = time.time()
            if not was_cancelled:
                assert bundle.machine is not None
                if bundle.hostname not in bundle.machine:
                    cmd = f'{RSYNC} {username}@{machine}:{result_file} {result_file} 2>/dev/null'
                    logger.info(
                        f"{uuid}/{fname}: Fetching results from {username}@{machine} via {cmd}"
                    )
                    try:
                        run_silently(cmd)
                    except subprocess.CalledProcessError:
                        logger.critical(f'Failed to copy {username}@{machine}:{result_file}!')
                    run_silently(f'{SSH} {username}@{machine}'
                                 f' "/bin/rm -f {code_file} {result_file}"')
            dur = bundle.end_ts - bundle.start_ts
            self.histogram.add_item(dur)
            assert bundle.worker is not None
            self.status.record_release_worker_already_locked(
                bundle.worker,
                bundle.uuid,
                was_cancelled
            )

        # Only the original worker should unpickle the file contents
        # though since it's the only one whose result matters. The
        # original is also the only job that may delete result_file
        # from disk. Note that the original may have been cancelled
        # if one of the backups finished first; it still must read the
        # result from disk.
        if is_original:
            logger.debug(f"{uuid}/{fname}: Unpickling {result_file}.")
            try:
                with open(f'{result_file}', 'rb') as rb:
                    serialized = rb.read()
                result = cloudpickle.loads(serialized)
            except Exception as e:
                msg = f'Failed to load {result_file}'
                logger.critical(msg)
                bundle.failure_count += 1
                self.release_worker(bundle.worker)
                raise Exception(e)
            os.remove(f'{result_file}')
            os.remove(f'{code_file}')

            # Notify any backups that the original is done so they
            # should stop ASAP. Do this whether or not we finished
            # first since there could be more than one backup.
            if bundle.backup_bundles is not None:
                for backup in bundle.backup_bundles:
                    logger.debug(
                        f'{uuid}/{fname}: Notifying backup {backup.uuid} that it\'s cancelled'
                    )
                    backup.is_cancelled.set()
        else:
            # This is a backup job. Backup results don't matter; they
            # just need to leave the result file in the right place for
            # their originals to read/unpickle later.
            result = None

            # Tell the original to stop if we finished first.
            if not was_cancelled:
                logger.debug(
                    f'{uuid}/{fname}: Notifying original {bundle.src_bundle.uuid} that it\'s cancelled'
                )
                bundle.src_bundle.is_cancelled.set()

        assert bundle.worker is not None
        self.release_worker(bundle.worker)
        self.adjust_task_count(-1)
        return result

    def create_original_bundle(self, pickle, fname: str):
        # Imported lazily, presumably to avoid a circular import.
        from string_utils import generate_uuid
        uuid = generate_uuid(as_hex=True)
        code_file = f'/tmp/{uuid}.code.bin'
        result_file = f'/tmp/{uuid}.result.bin'

        logger.debug(f'Writing pickled code to {code_file}')
        with open(f'{code_file}', 'wb') as wb:
            wb.write(pickle)

        bundle = BundleDetails(
            pickled_code = pickle,
            uuid = uuid,
            fname = fname,
            worker = None,
            username = None,
            machine = None,
            hostname = platform.node(),
            code_file = code_file,
            result_file = result_file,
            pid = 0,
            start_ts = time.time(),
            end_ts = 0.0,
            too_slow = False,
            super_slow = False,
            src_bundle = None,
            is_cancelled = threading.Event(),
            was_cancelled = False,
            backup_bundles = [],
            failure_count = 0,
        )
        self.status.record_bundle_details(bundle)
        logger.debug(f'{uuid}/{fname}: Created original bundle')
        return bundle

    def create_backup_bundle(self, src_bundle: BundleDetails):
        assert src_bundle.backup_bundles is not None
        n = len(src_bundle.backup_bundles)
        uuid = src_bundle.uuid + f'_backup#{n}'

        backup_bundle = BundleDetails(
            pickled_code = src_bundle.pickled_code,
            uuid = uuid,
            fname = src_bundle.fname,
            worker = None,
            username = None,
            machine = None,
            hostname = src_bundle.hostname,
            code_file = src_bundle.code_file,
            result_file = src_bundle.result_file,
            pid = 0,
            start_ts = time.time(),
            end_ts = 0.0,
            too_slow = False,
            super_slow = False,
            src_bundle = src_bundle,
            is_cancelled = threading.Event(),
            was_cancelled = False,
            backup_bundles = None,    # backup backups not allowed
            failure_count = 0,
        )
        src_bundle.backup_bundles.append(backup_bundle)
        self.status.record_bundle_details_already_locked(backup_bundle)
        logger.debug(f'{uuid}/{src_bundle.fname}: Created backup bundle')
        return backup_bundle

    def schedule_backup_for_bundle(self,
                                   src_bundle: BundleDetails):
        assert self.status.lock.locked()
        backup_bundle = self.create_backup_bundle(src_bundle)
        logger.debug(
            f'{backup_bundle.uuid}/{backup_bundle.fname}: Scheduling backup for execution...'
        )
        self._helper_executor.submit(self.launch, backup_bundle)

        # Results from backups don't matter; if they finish first
        # they will move the result_file to this machine and let
        # the original pick them up and unpickle them.

    @overrides
    def submit(self,
               function: Callable,
               *args,
               **kwargs) -> fut.Future:
        pickle = make_cloud_pickle(function, *args, **kwargs)
        bundle = self.create_original_bundle(pickle, function.__name__)
        self.total_bundles_submitted += 1
        return self._helper_executor.submit(self.launch, bundle)

    @overrides
    def shutdown(self, wait=True) -> None:
        self._helper_executor.shutdown(wait)
        logger.debug(f'Shutting down RemoteExecutor {self.title}')
        print(self.histogram)

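# Usage sketch (hypothetical worker records; the remote_worker.py helper
# referenced in launch() must exist on the remote machine):
#
#   workers = [
#       RemoteWorkerRecord(username='me', machine='host1', weight=1, count=2),
#   ]
#   executor = RemoteExecutor(workers, WeightedRandomRemoteWorkerSelectionPolicy())
#   future = executor.submit(some_function, some_arg)
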

@singleton
class DefaultExecutors(object):
    def __init__(self):
        self.thread_executor: Optional[ThreadExecutor] = None
        self.process_executor: Optional[ProcessExecutor] = None
        self.remote_executor: Optional[RemoteExecutor] = None

    def ping(self, host) -> bool:
        logger.debug(f'RUN> ping -c 1 {host}')
        command = ['ping', '-c', '1', host]
        return subprocess.call(
            command,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        ) == 0

    def thread_pool(self) -> ThreadExecutor:
        if self.thread_executor is None:
            self.thread_executor = ThreadExecutor()
        return self.thread_executor

    def process_pool(self) -> ProcessExecutor:
        if self.process_executor is None:
            self.process_executor = ProcessExecutor()
        return self.process_executor

    def remote_pool(self) -> RemoteExecutor:
        if self.remote_executor is None:
            logger.info('Looking for some helper machines...')
            pool: List[RemoteWorkerRecord] = []

            # NB: the original assigns each machine its own weight and
            # count; those values (and the usernames) were elided from
            # this excerpt. The placeholders below are illustrative
            # only; the username is guessed from the /home/scott path
            # used in launch() above.
            for machine in [
                    'cheetah.house',
                    'video.house',
                    'wannabe.house',
                    'meerkat.cabin',
                    'backup.house',
                    'kiosk.house',
                    'puma.cabin',
            ]:
                if self.ping(machine):
                    logger.info(f'Found {machine}')
                    pool.append(
                        RemoteWorkerRecord(
                            username = 'scott',  # assumed; elided in excerpt
                            machine = machine,
                            weight = 1,          # placeholder; elided in excerpt
                            count = 1,           # placeholder; elided in excerpt
                        ),
                    )

            # The controller machine has a lot to do; go easy on it.
            for record in pool:
                if record.machine == platform.node() and record.count > 1:
                    logger.info(f'Reducing workload for {record.machine}.')
                    record.count = 1

            policy = WeightedRandomRemoteWorkerSelectionPolicy()
            policy.register_worker_pool(pool)
            self.remote_executor = RemoteExecutor(pool, policy)
        return self.remote_executor
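

# End-to-end sketch (assumes this file is importable as `executors`; the
# thread pool needs no remote machines):
#
#   from executors import DefaultExecutors
#
#   def add(a: int, b: int) -> int:
#       return a + b
#
#   pool = DefaultExecutors().thread_pool()
#   future = pool.submit(add, 2, 3)
#   assert future.result() == 5
#   pool.shutdown()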