3 from __future__ import annotations
5 from abc import ABC, abstractmethod
6 import concurrent.futures as fut
7 from collections import defaultdict
8 from dataclasses import dataclass
17 from typing import Any, Callable, Dict, List, Optional, Set
19 import cloudpickle # type: ignore
21 from ansi import bg, fg, underline, reset
24 from exec_utils import run_silently, cmd_in_background
25 from decorator_utils import singleton
26 import histogram as hist
28 logger = logging.getLogger(__name__)
# Register this module's command-line flags with the central config
# system.  NOTE(review): the parser.add_argument(...) call headers are not
# visible in this chunk; only argument fragments appear below.
parser = config.add_commandline_args(
    f"Executors ({__file__})",
    "Args related to processing executors."
    # --executors_threadpool_size: worker count for ThreadExecutor.
    '--executors_threadpool_size',
    help='Number of threads in the default threadpool, leave unset for default',
    # --executors_processpool_size: worker count for ProcessExecutor.
    '--executors_processpool_size',
    help='Number of processes in the default processpool, leave unset for default',
    # --executors_schedule_remote_backups: allow duplicative backup bundles
    # when a remote bundle looks slow (see RemoteExecutor.heartbeat).
    '--executors_schedule_remote_backups',
    action=argparse_utils.ActionNoYes,
    help='Should we schedule duplicative backup work if a remote bundle is slow',
    # --executors_max_bundle_failures: per-bundle retry budget before a
    # bundle is declared "poison" and abandoned.
    '--executors_max_bundle_failures',
    help='Maximum number of failures before giving up on a bundle',

# Shell command prefixes used to ship code/results between machines and to
# invoke the remote worker over ssh.
RSYNC = 'rsync -q --no-motd -W --ignore-existing --timeout=60 --size-only -z'
SSH = 'ssh -oForwardX11=no'
def make_cloud_pickle(fun, *args, **kwargs):
    """Serialize a deferred call (callable + args + kwargs) via cloudpickle.

    Returns the pickled bytes; used to ship work to process pools and
    remote workers which later unpickle and invoke it.
    """
    # Lazy %-style logging args: the message is only formatted when DEBUG
    # logging is actually enabled.
    logger.debug("Making cloudpickled bundle at %s", fun.__name__)
    return cloudpickle.dumps((fun, args, kwargs))
class BaseExecutor(ABC):
    """Abstract base for the executors in this file.

    Tracks an in-flight task count and a histogram of task durations that
    subclasses print at shutdown.
    """

    def __init__(self, *, title=''):
        # Bucketized histogram of task runtimes (seconds).
        # NOTE(review): the bucket arguments are elided in this chunk —
        # confirm against hist.SimpleHistogram.
        self.histogram = hist.SimpleHistogram(
            hist.SimpleHistogram.n_evenly_spaced_buckets(
        # (signature fragment — the enclosing abstract submit(...) def line
        # is not visible in this chunk)
        **kwargs) -> fut.Future:
        # (signature fragment — the enclosing abstract shutdown(...) def
        # line is not visible in this chunk)
        wait: bool = True) -> None:
93 def adjust_task_count(self, delta: int) -> None:
94 self.task_count += delta
95 logger.debug(f'Executor current task count is {self.task_count}')
class ThreadExecutor(BaseExecutor):
    """Executor backed by a concurrent.futures.ThreadPoolExecutor."""

    # (def __init__ header line is not visible in this chunk)
            max_workers: Optional[int] = None):
        # Pool size: explicit argument wins, else --executors_threadpool_size.
        # NOTE(review): if neither is provided, 'workers' may be unbound on
        # the next line unless an elided line initializes it — confirm.
        if max_workers is not None:
            workers = max_workers
        elif 'executors_threadpool_size' in config.config:
            workers = config.config['executors_threadpool_size']
        logger.debug(f'Creating threadpool executor with {workers} workers')
        self._thread_pool_executor = fut.ThreadPoolExecutor(
            thread_name_prefix="thread_executor_helper"

    def run_local_bundle(self, fun, *args, **kwargs):
        """Run fun(*args, **kwargs) on a pool thread, timing the call and
        recording the duration in the histogram."""
        logger.debug(f"Running local bundle at {fun.__name__}")
        # (timing start line elided in this chunk)
        result = fun(*args, **kwargs)
        # (timing end line elided in this chunk)
        self.adjust_task_count(-1)
        duration = end - start
        logger.debug(f"{fun.__name__} finished; used {duration:.1f}s")
        self.histogram.add_item(duration)

    # submit: schedule function(*args, **kwargs) on the local thread pool.
    # (def header lines elided in this chunk)
        **kwargs) -> fut.Future:
        self.adjust_task_count(+1)
        # Pack the callable in with the args so run_local_bundle can unpack.
        newargs.append(function)
        return self._thread_pool_executor.submit(
            self.run_local_bundle,

    # shutdown: print stats and stop the pool.  (def header elided.)
        wait = True) -> None:
        logger.debug(f'Shutting down threadpool executor {self.title}')
        print(self.histogram)
        self._thread_pool_executor.shutdown(wait)
class ProcessExecutor(BaseExecutor):
    """Executor backed by a concurrent.futures.ProcessPoolExecutor; work is
    shipped to child processes as cloudpickles."""

    # (def __init__ header lines elided in this chunk)
        # Pool size: explicit argument wins, else --executors_processpool_size.
        # NOTE(review): 'workers' may be unbound if neither is set, unless an
        # elided line initializes it — confirm.
        if max_workers is not None:
            workers = max_workers
        elif 'executors_processpool_size' in config.config:
            workers = config.config['executors_processpool_size']
        logger.debug(f'Creating processpool executor with {workers} workers.')
        self._process_executor = fut.ProcessPoolExecutor(

    def run_cloud_pickle(self, pickle):
        """Unpickle a cloudpickled (fun, args, kwargs) tuple and run it.
        Executed inside a pool worker process."""
        fun, args, kwargs = cloudpickle.loads(pickle)
        logger.debug(f"Running pickled bundle at {fun.__name__}")
        result = fun(*args, **kwargs)
        self.adjust_task_count(-1)
        # NOTE(review): the 'return result' line appears to be elided here.

    # submit: pickle the call and hand it to the process pool.
    # (def header lines elided in this chunk)
        **kwargs) -> fut.Future:
        self.adjust_task_count(+1)
        pickle = make_cloud_pickle(function, *args, **kwargs)
        result = self._process_executor.submit(
            self.run_cloud_pickle,
        # Record the task's duration in the histogram when it completes.
        result.add_done_callback(
            lambda _: self.histogram.add_item(

    def shutdown(self, wait=True) -> None:
        """Stop the process pool (optionally waiting) and print stats."""
        logger.debug(f'Shutting down processpool executor {self.title}')
        self._process_executor.shutdown(wait)
        print(self.histogram)

    def __getstate__(self):
        """Pickle support: never serialize the live pool itself."""
        state = self.__dict__.copy()
        state['_process_executor'] = None
        # NOTE(review): the 'return state' line appears to be elided here.
class RemoteWorkerRecord:
    """One remote machine (username/machine plus, presumably, weight and
    count fields — field declarations are elided in this chunk)."""
        # Body of __hash__ (def line elided): identity is (user, machine).
        return hash((self.username, self.machine))
        # Body of __repr__ (def line elided).
        return f'{self.username}@{self.machine}'
    # Fields of BundleDetails (class/dataclass header is not visible in
    # this chunk).  A bundle is one unit of remote work.
    worker: Optional[RemoteWorkerRecord]   # assigned worker, if any
    username: Optional[str]                # copied from worker at launch
    machine: Optional[str]                 # copied from worker at launch
    src_bundle: BundleDetails              # original bundle (backups only)
    is_cancelled: threading.Event          # set to tell this bundle to stop
    backup_bundles: Optional[List[BundleDetails]]  # None for backups
class RemoteExecutorStatus:
    """Bookkeeping for RemoteExecutor: which workers run which bundles,
    per-bundle timings, and a periodic human-readable status dump."""

    def __init__(self, total_worker_count: int) -> None:
        # total_worker_count: sum of slots across all remote workers.
        self.worker_count = total_worker_count
        self.known_workers: Set[RemoteWorkerRecord] = set()
        # Bundle uuid -> wall-clock start/end timestamps.
        self.start_per_bundle: Dict[str, float] = defaultdict(float)
        self.end_per_bundle: Dict[str, float] = defaultdict(float)
        # (value-type annotations elided in this chunk)
        self.finished_bundle_timings_per_worker: Dict[
        self.in_flight_bundles_by_worker: Dict[
        self.bundle_details_by_uuid: Dict[str, BundleDetails] = {}
        # Flat list of all finished-bundle latencies (seconds).
        self.finished_bundle_timings: List[float] = []
        self.last_periodic_dump: Optional[float] = None
        self.total_bundles_submitted = 0

        # Protects reads and modification using self. Also used
        # as a memory fence for modifications to bundle.
        self.lock = threading.Lock()

    def record_acquire_worker(
            self,  # NOTE(review): some signature/body lines elided here
            worker: RemoteWorkerRecord,
        # Locking wrapper around the _already_locked variant below.
        self.record_acquire_worker_already_locked(

    def record_acquire_worker_already_locked(
            self,  # NOTE(review): some signature lines elided here
            worker: RemoteWorkerRecord,
        # Caller must hold self.lock.
        assert self.lock.locked()
        self.known_workers.add(worker)
        self.start_per_bundle[uuid] = time.time()
        x = self.in_flight_bundles_by_worker.get(worker, set())
        # NOTE(review): the line adding uuid to x appears to be elided.
        self.in_flight_bundles_by_worker[worker] = x

    def record_bundle_details(
            self,  # NOTE(review): locking line elided here
            details: BundleDetails) -> None:
        # Locking wrapper around the _already_locked variant below.
        self.record_bundle_details_already_locked(details)

    def record_bundle_details_already_locked(
            self,
            details: BundleDetails) -> None:
        # Caller must hold self.lock.
        assert self.lock.locked()
        self.bundle_details_by_uuid[details.uuid] = details

    def record_release_worker_already_locked(
            self,  # NOTE(review): some signature lines elided here
            worker: RemoteWorkerRecord,
        # Caller must hold self.lock.  Records end time and, for bundles
        # that actually finished (not cancelled), their latency.
        assert self.lock.locked()
        self.end_per_bundle[uuid] = ts
        self.in_flight_bundles_by_worker[worker].remove(uuid)
        if not was_cancelled:
            bundle_latency = ts - self.start_per_bundle[uuid]
            x = self.finished_bundle_timings_per_worker.get(worker, list())
            x.append(bundle_latency)
            self.finished_bundle_timings_per_worker[worker] = x
            self.finished_bundle_timings.append(bundle_latency)

    def total_in_flight(self) -> int:
        """Count of bundles currently executing; lock must be held."""
        assert self.lock.locked()
        # (accumulator initialization line elided in this chunk)
        for worker in self.known_workers:
            total_in_flight += len(self.in_flight_bundles_by_worker[worker])
        return total_in_flight

    def total_idle(self) -> int:
        """Count of free worker slots; lock must be held."""
        assert self.lock.locked()
        return self.worker_count - self.total_in_flight()

        # Body of the status-rendering method (def line elided in this
        # chunk; presumably __repr__).  Builds a multi-line colored report
        # and flags slow bundles for backups as a side effect.
        assert self.lock.locked()
        total_finished = len(self.finished_bundle_timings)
        total_in_flight = self.total_in_flight()
        ret = f'\n\n{underline()}Remote Executor Pool Status{reset()}: '
        # Global p50/p95 need at least two finished bundles.
        if len(self.finished_bundle_timings) > 1:
            qall = numpy.quantile(self.finished_bundle_timings, [0.5, 0.95])
            f'⏱=∀p50:{qall[0]:.1f}s, ∀p95:{qall[1]:.1f}s, '
            f'✅={total_finished}/{self.total_bundles_submitted}, '
            f'💻n={total_in_flight}/{self.worker_count}\n'
            f' ✅={total_finished}/{self.total_bundles_submitted}, '
            f'💻n={total_in_flight}/{self.worker_count}\n'
        # Per-worker section: timings and in-flight bundles.
        for worker in self.known_workers:
            ret += f'  {fg("lightning yellow")}{worker.machine}{reset()}: '
            timings = self.finished_bundle_timings_per_worker.get(worker, [])
            qworker = numpy.quantile(timings, [0.5, 0.95])
            ret += f' 💻p50: {qworker[0]:.1f}s, 💻p95: {qworker[1]:.1f}s\n'
            ret += f'    ...finished {count} total bundle(s) so far\n'
            in_flight = len(self.in_flight_bundles_by_worker[worker])
            ret += f'    ...{in_flight} bundles currently in flight:\n'
            for bundle_uuid in self.in_flight_bundles_by_worker[worker]:
                details = self.bundle_details_by_uuid.get(
                pid = str(details.pid) if details is not None else "TBD"
                sec = ts - self.start_per_bundle[bundle_uuid]
                ret += f'       (pid={pid}): {bundle_uuid} for {sec:.1f}s so far '
                # Flag bundles running beyond this worker's quantiles...
                if qworker is not None:
                    ret += f'{bg("red")}>💻p95{reset()} '
                elif sec > qworker[0]:
                    ret += f'{fg("red")}>💻p50{reset()} '
                # ...and beyond the global quantiles; side effect: mark the
                # bundle so heartbeat() schedules (another) backup.
                if sec > qall[1] * 1.5:
                    ret += f'{bg("red")}!!!{reset()}'
                    if details is not None:
                        logger.debug(f'Flagging {details.uuid} for another backup')
                        details.super_slow = True
                ret += f'{bg("red")}>∀p95{reset()} '
                if details is not None:
                    logger.debug(f'Flagging {details.uuid} for a backup')
                    details.too_slow = True
                ret += f'{fg("red")}>∀p50{reset()}'

    def periodic_dump(self, total_bundles_submitted: int) -> None:
        """Print the status report, rate-limited to once per 5 seconds.
        Lock must be held."""
        assert self.lock.locked()
        self.total_bundles_submitted = total_bundles_submitted
        # (timestamp acquisition / print lines elided in this chunk)
            self.last_periodic_dump is None
            or ts - self.last_periodic_dump > 5.0
        self.last_periodic_dump = ts
class RemoteWorkerSelectionPolicy(ABC):
    """Strategy interface: decides which remote worker gets the next bundle."""

    def register_worker_pool(self, workers):
        # Remember the pool; called once by RemoteExecutor at construction.
        self.workers = workers

    # Abstract interface (decorator lines elided in this chunk).
    def is_worker_available(self) -> bool:
        # (fragment of the abstract acquire_worker signature below)
        machine_to_avoid = None
    ) -> Optional[RemoteWorkerRecord]:
class WeightedRandomRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
    """Pick a worker at random, weighted by each record's count × weight."""

    def is_worker_available(self) -> bool:
        # Scan the pool for any free capacity (test body elided in chunk).
        for worker in self.workers:

        # (fragment of acquire_worker signature — def line elided)
        machine_to_avoid = None
    ) -> Optional[RemoteWorkerRecord]:
        # Build a grab-bag with each worker repeated count*weight times so a
        # uniform draw is effectively weighted.
        for worker in self.workers:
            for x in range(0, worker.count):
                for y in range(0, worker.weight):
                    grabbag.append(worker)
        # Up to 5 draws; after 2 failed attempts, stop avoiding
        # machine_to_avoid and take whatever we get.
        for _ in range(0, 5):
            random.shuffle(grabbag)
            if worker.machine != machine_to_avoid or _ > 2:
        logger.debug(f'Selected worker {worker}')
        logger.warning("Couldn't find a worker; go fish.")
class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
    """Hand out workers in rotating order, remembering the last index used."""

    def __init__(self) -> None:
        # (index initialization line elided in this chunk)

    def is_worker_available(self) -> bool:
        # Scan the pool for any free capacity (test body elided in chunk).
        for worker in self.workers:

        # (fragment of acquire_worker signature — def line elided).
        # NOTE(review): machine_to_avoid appears unused by this policy.
        machine_to_avoid: Optional[str] = None
    ) -> Optional[RemoteWorkerRecord]:
        # Walk the list starting at the remembered index, wrapping once.
        worker = self.workers[x]
        if x >= len(self.workers):
        logger.debug(f'Selected worker {worker}')
        if x >= len(self.workers):
        logger.warning("Couldn't find a worker; go fish.")
class RemoteExecutor(BaseExecutor):
    """Executor that ships cloudpickled bundles to remote machines over
    ssh/rsync and may schedule duplicate 'backup' bundles for slow work."""

    # (def __init__ header line elided in this chunk)
            workers: List[RemoteWorkerRecord],
            policy: RemoteWorkerSelectionPolicy) -> None:
        self.workers = workers
        # Total slots = sum of per-machine counts.
        self.worker_count = 0
        for worker in self.workers:
            self.worker_count += worker.count
        if self.worker_count <= 0:
            msg = f"We need somewhere to schedule work; count was {self.worker_count}"
        self.policy.register_worker_pool(self.workers)
        # Condition used to block submitters until a worker frees up.
        self.cv = threading.Condition()
        # One local thread per remote slot drives launch()/post_launch_work().
        self._helper_executor = fut.ThreadPoolExecutor(
            thread_name_prefix="remote_executor_helper",
            max_workers=self.worker_count,
        self.status = RemoteExecutorStatus(self.worker_count)
        self.total_bundles_submitted = 0
            f'Creating remote processpool with {self.worker_count} remote endpoints.'
499 def is_worker_available(self) -> bool:
500 return self.policy.is_worker_available()
        # (def acquire_worker header lines elided in this chunk)
        machine_to_avoid: Optional[str] = None
    ) -> Optional[RemoteWorkerRecord]:
        # Delegate worker choice to the policy; may return None if busy.
        return self.policy.acquire_worker(machine_to_avoid)
    def find_available_worker_or_block(
            # (self parameter line elided in this chunk)
            machine_to_avoid: Optional[str] = None
    ) -> RemoteWorkerRecord:
        """Acquire a worker, blocking until one is free.

        NOTE(review): the waiting mechanism (presumably self.cv.wait) is in
        elided lines — confirm.
        """
        while not self.is_worker_available():
            worker = self.acquire_worker(machine_to_avoid)
            if worker is not None:
        # Unreachable: availability was checked and acquire retried above.
        msg = "We should never reach this point in the code"
    def release_worker(self, worker: RemoteWorkerRecord) -> None:
        """Return a worker slot to the pool (the rest of the body — likely
        policy release + cv notify — is elided in this chunk)."""
        logger.debug(f'Released worker {worker}')
    def heartbeat(self) -> None:
        """Periodic maintenance: dump status and consider scheduling backup
        bundles for slow in-flight work."""
        with self.status.lock:
            # Regular progress report
            self.status.periodic_dump(self.total_bundles_submitted)

            # Look for bundles to reschedule
            # Only once enough timings exist for meaningful quantiles.
            if len(self.status.finished_bundle_timings) > 7:
                for worker, bundle_uuids in self.status.in_flight_bundles_by_worker.items():
                    for uuid in bundle_uuids:
                        bundle = self.status.bundle_details_by_uuid.get(uuid, None)
                        # Only originals (src_bundle is None) get backups,
                        # and only when enabled by the config flag.
                        bundle is not None and
                        bundle.src_bundle is None and
                        config.config['executors_schedule_remote_backups']
                        self.consider_backup_for_bundle(bundle)
    def consider_backup_for_bundle(self, bundle: BundleDetails) -> None:
        """Decide whether a slow bundle deserves a duplicate backup bundle.
        Caller must hold status.lock.  (Condition headers are elided in this
        chunk; visible fragments suggest two tiers: 'too_slow' gets one
        backup, 'super_slow' may get a second if enough workers are idle.)"""
        assert self.status.lock.locked()
            and len(bundle.backup_bundles) == 0    # one backup per
            msg = f"*** Rescheduling {bundle.pid}/{bundle.uuid} ***"
            self.schedule_backup_for_bundle(bundle)
            and len(bundle.backup_bundles) < 2     # two backups in dire situations
            and self.status.total_idle() > 4
            msg = f"*** Rescheduling {bundle.pid}/{bundle.uuid} ***"
            self.schedule_backup_for_bundle(bundle)
    def check_if_cancelled(self, bundle: BundleDetails) -> bool:
        """Non-blocking check of the bundle's cancellation event; marks the
        bundle was_cancelled when set.  (Return statements are elided in
        this chunk.)"""
        with self.status.lock:
            # wait(timeout=0.0) is a non-blocking poll of the Event.
            if bundle.is_cancelled.wait(timeout=0.0):
                logger.debug(f'Bundle {bundle.uuid} is cancelled, bail out.')
                bundle.was_cancelled = True
    def launch(self, bundle: BundleDetails) -> Any:
        """Find a worker for bundle or block until one is available."""
        self.adjust_task_count(+1)
        hostname = bundle.hostname

        # Try not to schedule a backup on the same host as the original.
        if bundle.src_bundle is not None:
            avoid_machine = bundle.src_bundle.machine
        while worker is None:
            worker = self.find_available_worker_or_block(avoid_machine)
        # Stamp the bundle with its assigned worker's identity.
        bundle.worker = worker
        machine = bundle.machine = worker.machine
        username = bundle.username = worker.username
        self.status.record_acquire_worker(worker, uuid)
        logger.debug(f'Running bundle {uuid} on {worker}...')

        # Before we do any work, make sure the bundle is still viable.
        if self.check_if_cancelled(bundle):
                return self.post_launch_work(bundle)
            except Exception as e:
                logger.info(f"{uuid}/{fname}: bundle seems to have failed?!")
                # NOTE(review): retry is implemented by recursing into
                # launch(); bounded by executors_max_bundle_failures.
                if bundle.failure_count < config.config['executors_max_bundle_failures']:
                    return self.launch(bundle)
                logger.info(f"{uuid}/{fname}: bundle is poison, giving up on it.")

        # Send input to machine if it's not local.
        if hostname not in machine:
            cmd = f'{RSYNC} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
            logger.info(f"{uuid}/{fname}: Copying work to {worker} via {cmd}")

        # Kick off the work on the remote machine in the background.
        cmd = (f'{SSH} {bundle.username}@{bundle.machine} '
               f'"source py39-venv/bin/activate &&'
               f' /home/scott/lib/python_modules/remote_worker.py'
               f' --code_file {bundle.code_file} --result_file {bundle.result_file}"')
        p = cmd_in_background(cmd, silent=True)
        bundle.pid = pid = p.pid
        logger.info(f"{uuid}/{fname}: Start training on {worker} via {cmd} (background pid {pid})")
        # (polling loop lines elided in this chunk)
        except subprocess.TimeoutExpired:
            # Both source and backup bundles can be cancelled by
            # the other depending on which finishes first.
            if self.check_if_cancelled(bundle):
                f"{uuid}/{fname}: pid {pid} has finished its work normally."
        return self.post_launch_work(bundle)
        except Exception as e:
            logger.info(f"{uuid}: Bundle seems to have failed?!")
            if bundle.failure_count < config.config['executors_max_bundle_failures']:
                return self.launch(bundle)
            logger.info(f"{uuid}: Bundle is poison, giving up on it.")
    def post_launch_work(self, bundle: BundleDetails) -> Any:
        """After remote work ends: fetch the result file, record timings,
        unpickle (originals only), and cancel the counterpart bundles."""
        with self.status.lock:
            # A bundle with no src_bundle is an original; otherwise backup.
            is_original = bundle.src_bundle is None
            was_cancelled = bundle.was_cancelled
            username = bundle.username
            machine = bundle.machine
            result_file = bundle.result_file
            code_file = bundle.code_file

            # Whether original or backup, if we finished first we must
            # fetch the results if the computation happened on a
            bundle.end_ts = time.time()
            if not was_cancelled:
                assert bundle.machine is not None
                if bundle.hostname not in bundle.machine:
                    cmd = f'{RSYNC} {username}@{machine}:{result_file} {result_file} 2>/dev/null'
                        f"{uuid}/{fname}: Fetching results from {username}@{machine} via {cmd}"
                    except subprocess.CalledProcessError:
                        logger.critical(f'Failed to copy {username}@{machine}:{result_file}!')
                    # Best-effort cleanup of the remote copies.
                    run_silently(f'{SSH} {username}@{machine}'
                                 f' "/bin/rm -f {code_file} {result_file}"')
                dur = bundle.end_ts - bundle.start_ts
                self.histogram.add_item(dur)
            assert bundle.worker is not None
            self.status.record_release_worker_already_locked(

        # Only the original worker should unpickle the file contents
        # though since it's the only one whose result matters. The
        # original is also the only job that may delete result_file
        # from disk. Note that the original may have been cancelled
        # if one of the backups finished first; it still must read the
            logger.debug(f"{uuid}/{fname}: Unpickling {result_file}.")
            with open(f'{result_file}', 'rb') as rb:
                serialized = rb.read()
            result = cloudpickle.loads(serialized)
            except Exception as e:
                msg = f'Failed to load {result_file}'
                # Count the failure; the caller decides whether to retry.
                bundle.failure_count += 1
                self.release_worker(bundle.worker)
            os.remove(f'{result_file}')
            os.remove(f'{code_file}')

            # Notify any backups that the original is done so they
            # should stop ASAP. Do this whether or not we
            # finished first since there could be more than one
            if bundle.backup_bundles is not None:
                for backup in bundle.backup_bundles:
                        f'{uuid}/{fname}: Notifying backup {backup.uuid} that it\'s cancelled'
                    backup.is_cancelled.set()

        # This is a backup job.

            # Backup results don't matter, they just need to leave the
            # result file in the right place for their originals to
            # read/unpickle later.

            # Tell the original to stop if we finished first.
            if not was_cancelled:
                    f'{uuid}/{fname}: Notifying original {bundle.src_bundle.uuid} that it\'s cancelled'
                bundle.src_bundle.is_cancelled.set()

        assert bundle.worker is not None
        self.release_worker(bundle.worker)
        self.adjust_task_count(-1)
    def create_original_bundle(self, pickle, fname: str):
        """Wrap a cloudpickle payload into a new original BundleDetails,
        writing the code to a /tmp file and registering it with status."""
        from string_utils import generate_uuid
        uuid = generate_uuid(as_hex=True)
        # Code goes in; the remote worker leaves its output at result_file.
        code_file = f'/tmp/{uuid}.code.bin'
        result_file = f'/tmp/{uuid}.result.bin'
        logger.debug(f'Writing pickled code to {code_file}')
        with open(f'{code_file}', 'wb') as wb:
        # (some field lines are elided in this chunk)
        bundle = BundleDetails(
            pickled_code = pickle,
            hostname = platform.node(),
            code_file = code_file,
            result_file = result_file,
            start_ts = time.time(),
            is_cancelled = threading.Event(),
            was_cancelled = False,
        self.status.record_bundle_details(bundle)
        logger.debug(f'{uuid}/{fname}: Created original bundle')
    def create_backup_bundle(self, src_bundle: BundleDetails):
        """Clone an original bundle into a backup sharing the same code and
        result files; uuid gets a '_backup#N' suffix."""
        assert src_bundle.backup_bundles is not None
        n = len(src_bundle.backup_bundles)
        uuid = src_bundle.uuid + f'_backup#{n}'
        # (some field lines are elided in this chunk)
        backup_bundle = BundleDetails(
            pickled_code = src_bundle.pickled_code,
            fname = src_bundle.fname,
            hostname = src_bundle.hostname,
            code_file = src_bundle.code_file,
            result_file = src_bundle.result_file,
            start_ts = time.time(),
            src_bundle = src_bundle,
            is_cancelled = threading.Event(),
            was_cancelled = False,
            backup_bundles = None,    # backup backups not allowed
        src_bundle.backup_bundles.append(backup_bundle)
        self.status.record_bundle_details_already_locked(backup_bundle)
        logger.debug(f'{uuid}/{src_bundle.fname}: Created backup bundle')
    def schedule_backup_for_bundle(self,
                                   src_bundle: BundleDetails):
        """Create a backup for src_bundle and submit it for execution.
        Caller must hold status.lock."""
        assert self.status.lock.locked()
        backup_bundle = self.create_backup_bundle(src_bundle)
            f'{backup_bundle.uuid}/{backup_bundle.fname}: Scheduling backup for execution...'
        self._helper_executor.submit(self.launch, backup_bundle)

        # Results from backups don't matter; if they finish first
        # they will move the result_file to this machine and let
        # the original pick them up and unpickle them.
        # (def submit header lines elided in this chunk)
        **kwargs) -> fut.Future:
        # Pickle the call, wrap it in an original bundle, and drive it with
        # a helper thread running launch().
        pickle = make_cloud_pickle(function, *args, **kwargs)
        bundle = self.create_original_bundle(pickle, function.__name__)
        self.total_bundles_submitted += 1
        return self._helper_executor.submit(self.launch, bundle)
825 def shutdown(self, wait=True) -> None:
826 self._helper_executor.shutdown(wait)
827 logging.debug(f'Shutting down RemoteExecutor {self.title}')
828 print(self.histogram)
class DefaultExecutors(object):
    """Lazily-constructed default thread/process/remote executors.
    NOTE(review): presumably decorated @singleton — decorator line is not
    visible in this chunk."""

    # (def __init__ header line elided in this chunk)
        self.thread_executor: Optional[ThreadExecutor] = None
        self.process_executor: Optional[ProcessExecutor] = None
        self.remote_executor: Optional[RemoteExecutor] = None

    def ping(self, host) -> bool:
        """True if host answers a single ICMP ping; used to probe which
        remote workers are reachable.  (The call's argument list and the
        '== 0' comparison appear to be in elided lines.)"""
        command = ['ping', '-c', '1', host]
        return subprocess.call(
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,

    def thread_pool(self) -> ThreadExecutor:
        # Lazy singleton accessor for the default ThreadExecutor.
        if self.thread_executor is None:
            self.thread_executor = ThreadExecutor()
        return self.thread_executor

    def process_pool(self) -> ProcessExecutor:
        # Lazy singleton accessor for the default ProcessExecutor.
        if self.process_executor is None:
            self.process_executor = ProcessExecutor()
        return self.process_executor

    def remote_pool(self) -> RemoteExecutor:
        """Lazy singleton accessor for the default RemoteExecutor; pings a
        hard-coded list of machines and pools the reachable ones.
        NOTE(review): per-record weight/count arguments are elided."""
        if self.remote_executor is None:
            pool: List[RemoteWorkerRecord] = []
            if self.ping('cheetah.house'):
                    machine = 'cheetah.house',
            if self.ping('video.house'):
                    machine = 'video.house',
            if self.ping('wannabe.house'):
                    machine = 'wannabe.house',
            if self.ping('meerkat.cabin'):
                    machine = 'meerkat.cabin',
            if self.ping('backup.house'):
                    machine = 'backup.house',
            if self.ping('puma.cabin'):
                    machine = 'puma.cabin',
            policy = WeightedRandomRemoteWorkerSelectionPolicy()
            policy.register_worker_pool(pool)
            self.remote_executor = RemoteExecutor(pool, policy)
        return self.remote_executor