Stop using rsync in executors; this was a hack to work around some
[python_utils.git] / executors.py
1 #!/usr/bin/env python3
2
3 from __future__ import annotations
4
5 from abc import ABC, abstractmethod
6 import concurrent.futures as fut
7 from collections import defaultdict
8 from dataclasses import dataclass
9 import logging
10 import numpy
11 import os
12 import platform
13 import random
14 import subprocess
15 import threading
16 import time
17 from typing import Any, Callable, Dict, List, Optional, Set
18 import warnings
19
20 import cloudpickle  # type: ignore
21 from overrides import overrides
22
23 from ansi import bg, fg, underline, reset
24 import argparse_utils
25 import config
26 from decorator_utils import singleton
27 from exec_utils import run_silently, cmd_in_background, cmd_with_timeout
28 import histogram as hist
29
logger = logging.getLogger(__name__)

# Register this module's command-line flags with the shared config parser.
parser = config.add_commandline_args(
    f"Executors ({__file__})",
    "Args related to processing executors.",
)
parser.add_argument(
    '--executors_threadpool_size',
    help='Number of threads in the default threadpool, leave unset for default',
    metavar='#THREADS',
    type=int,
    default=None,
)
parser.add_argument(
    '--executors_processpool_size',
    help='Number of processes in the default processpool, leave unset for default',
    metavar='#PROCESSES',
    type=int,
    default=None,
)
parser.add_argument(
    '--executors_schedule_remote_backups',
    help='Should we schedule duplicative backup work if a remote bundle is slow',
    action=argparse_utils.ActionNoYes,
    default=True,
)
parser.add_argument(
    '--executors_max_bundle_failures',
    help='Maximum number of failures before giving up on a bundle',
    metavar='#FAILURES',
    type=int,
    default=3,
)

# Command lines used to run work on, and copy files to, remote workers.
SSH = '/usr/bin/ssh -oForwardX11=no'
SCP = '/usr/bin/scp'
66
def make_cloud_pickle(fun, *args, **kwargs):
    """Serialize a call so it can be shipped elsewhere and replayed.

    Returns the cloudpickle encoding of the tuple ``(fun, args, kwargs)``.
    """
    logger.debug(f"Making cloudpickled bundle at {fun.__name__}")
    call_tuple = (fun, args, kwargs)
    return cloudpickle.dumps(call_tuple)
70
71
class BaseExecutor(ABC):
    """Common base for executors: a title, a task counter and a latency histogram.

    Subclasses implement ``submit`` (returning a ``concurrent.futures.Future``)
    and ``shutdown``.
    """

    def __init__(self, *, title=''):
        """Initialize bookkeeping; ``title`` appears in shutdown log messages."""
        self.title = title
        self.task_count = 0
        # Bucket boundaries come from n_evenly_spaced_buckets(0, 500, 50).
        bucket_spec = hist.SimpleHistogram.n_evenly_spaced_buckets(0, 500, 50)
        self.histogram = hist.SimpleHistogram(bucket_spec)

    @abstractmethod
    def submit(self,
               function: Callable,
               *args,
               **kwargs) -> fut.Future:
        """Schedule ``function(*args, **kwargs)`` and return a Future for it."""

    @abstractmethod
    def shutdown(self,
                 wait: bool = True) -> None:
        """Shut the executor down; ``wait`` asks to block for outstanding work."""

    def adjust_task_count(self, delta: int) -> None:
        """Add ``delta`` to ``task_count`` and log the new value.

        NOTE(review): this read-modify-write is not lock-protected, yet
        subclasses call it from pool threads / callbacks; the count may drift
        under contention.
        """
        self.task_count = self.task_count + delta
        logger.debug(f'Executor current task count is {self.task_count}')
96         logger.debug(f'Executor current task count is {self.task_count}')
97
98
class ThreadExecutor(BaseExecutor):
    """Executor that runs callables on a local thread pool.

    Each successful task's wall time is added to ``self.histogram``; the
    outstanding-task count is kept via ``adjust_task_count``.
    """

    def __init__(self,
                 max_workers: Optional[int] = None):
        """Create the thread pool.

        Args:
            max_workers: number of threads.  When None, fall back to the
                --executors_threadpool_size flag, which may itself be None
                (letting ThreadPoolExecutor choose its default).
        """
        super().__init__()
        workers = None
        if max_workers is not None:
            workers = max_workers
        elif 'executors_threadpool_size' in config.config:
            workers = config.config['executors_threadpool_size']
        logger.debug(f'Creating threadpool executor with {workers} workers')
        self._thread_pool_executor = fut.ThreadPoolExecutor(
            max_workers=workers,
            thread_name_prefix="thread_executor_helper"
        )

    def run_local_bundle(self, fun, *args, **kwargs):
        """Run ``fun(*args, **kwargs)`` on a pool thread and return its result.

        The task count is decremented even if ``fun`` raises (otherwise it
        would drift upward on every failure); the duration is only added to
        the histogram on success.
        """
        logger.debug(f"Running local bundle at {fun.__name__}")
        start = time.time()
        try:
            result = fun(*args, **kwargs)
            duration = time.time() - start
        finally:
            self.adjust_task_count(-1)
        logger.debug(f"{fun.__name__} finished; used {duration:.1f}s")
        self.histogram.add_item(duration)
        return result

    @overrides
    def submit(self,
               function: Callable,
               *args,
               **kwargs) -> fut.Future:
        """Schedule ``function(*args, **kwargs)`` on the thread pool."""
        self.adjust_task_count(+1)
        try:
            return self._thread_pool_executor.submit(
                self.run_local_bundle,
                function,
                *args,
                **kwargs
            )
        except Exception:
            # Nothing was scheduled (e.g. pool already shut down); undo.
            self.adjust_task_count(-1)
            raise

    @overrides
    def shutdown(self,
                 wait=True) -> None:
        """Shut down the pool, then print the latency histogram.

        Printing after shutdown means that, with ``wait=True``, the histogram
        includes every task that was still running.
        """
        logger.debug(f'Shutting down threadpool executor {self.title}')
        self._thread_pool_executor.shutdown(wait)
        print(self.histogram)
146
147
class ProcessExecutor(BaseExecutor):
    """Executor that runs callables in a local pool of worker processes.

    Work is shipped as a cloudpickle'd ``(fun, args, kwargs)`` bundle.  Because
    ``self.run_cloud_pickle`` is what gets submitted, ``self`` is pickled into
    the child (see ``__getstate__``); all bookkeeping (task count, latency
    histogram) therefore happens in the parent, in a done-callback.
    """

    def __init__(self,
                 max_workers=None):
        """Create the process pool.

        Args:
            max_workers: number of processes.  When None, fall back to the
                --executors_processpool_size flag, which may itself be None
                (letting ProcessPoolExecutor choose its default).
        """
        super().__init__()
        workers = None
        if max_workers is not None:
            workers = max_workers
        elif 'executors_processpool_size' in config.config:
            workers = config.config['executors_processpool_size']
        logger.debug(f'Creating processpool executor with {workers} workers.')
        self._process_executor = fut.ProcessPoolExecutor(
            max_workers=workers,
        )

    def run_cloud_pickle(self, pickle):
        """Unpickle a bundle and run it; executes inside a pool process.

        This must not update bookkeeping: ``self`` here is a copy living in
        the child process, so changes (e.g. to task_count) would be lost.
        """
        fun, args, kwargs = cloudpickle.loads(pickle)
        logger.debug(f"Running pickled bundle at {fun.__name__}")
        return fun(*args, **kwargs)

    @overrides
    def submit(self,
               function: Callable,
               *args,
               **kwargs) -> fut.Future:
        """Schedule ``function(*args, **kwargs)`` on the process pool.

        When the returned Future completes (result, exception or cancel), the
        wall time since submission is added to the histogram and the task
        count is decremented -- both in this (parent) process.
        """
        start = time.time()
        self.adjust_task_count(+1)
        try:
            pickle = make_cloud_pickle(function, *args, **kwargs)
            result = self._process_executor.submit(
                self.run_cloud_pickle,
                pickle
            )
        except Exception:
            # Nothing was scheduled, so undo the increment above.
            self.adjust_task_count(-1)
            raise

        def _on_done(_: fut.Future) -> None:
            # Done-callbacks run in the process that added them (the parent),
            # so these updates are visible to callers of this executor.
            self.histogram.add_item(time.time() - start)
            self.adjust_task_count(-1)

        result.add_done_callback(_on_done)
        return result

    @overrides
    def shutdown(self, wait=True) -> None:
        """Shut down the pool (forwarding ``wait``), then print the histogram."""
        logger.debug(f'Shutting down processpool executor {self.title}')
        self._process_executor.shutdown(wait)
        print(self.histogram)

    def __getstate__(self):
        # The pool itself is not sent along when self is pickled into a child.
        state = self.__dict__.copy()
        state['_process_executor'] = None
        return state
198
199
class RemoteExecutorException(Exception):
    """Thrown when a bundle cannot be executed despite several retries.

    Also raised when a RemoteExecutor is built over workers with no capacity.
    """
203
204
@dataclass
class RemoteWorkerRecord:
    """A remote account (``username@machine``) that can run bundles.

    ``weight`` is a relative preference used when choosing workers and when
    scoring backups; ``count`` is the number of currently free slots and is
    decremented on acquire and incremented on release.
    """
    username: str
    machine: str
    weight: int
    count: int

    def __hash__(self):
        # Hash only the identity fields; weight/count are mutable state.
        identity = (self.username, self.machine)
        return hash(identity)

    def __repr__(self):
        return '{}@{}'.format(self.username, self.machine)
217
218
@dataclass
class BundleDetails:
    """Everything tracked about one unit of remote work (a "bundle")."""
    pickled_code: bytes
    uuid: str
    fname: str
    worker: Optional[RemoteWorkerRecord]
    username: Optional[str]
    machine: Optional[str]
    hostname: str
    code_file: str
    result_file: str
    pid: int  # 0 until the local ssh process has been started
    start_ts: float
    end_ts: float
    slower_than_local_p95: bool
    slower_than_global_p95: bool
    src_bundle: BundleDetails  # None for an original; the source for a backup
    is_cancelled: threading.Event
    was_cancelled: bool
    backup_bundles: Optional[List[BundleDetails]]
    failure_count: int

    def __repr__(self):
        """Short colored tag: ``<uuid suffix>/<fname>/<machine>``.

        A uuid whose last nine characters are ``'_backup'`` plus two more
        characters is treated as a backup: those nine characters are dropped
        and ``_b<last char>`` is appended to the displayed suffix.  The color
        is chosen from the last two characters of the (stripped) uuid read as
        hex.
        """
        base_uuid = self.uuid
        if base_uuid[-9:-2] != '_backup':
            tag = base_uuid[-6:]
        else:
            base_uuid = base_uuid[:-9]
            tag = f'{base_uuid[-6:]}_b{self.uuid[-1:]}'

        palette = [
            fg('violet red'),
            fg('red'),
            fg('orange'),
            fg('peach orange'),
            fg('yellow'),
            fg('marigold yellow'),
            fg('green yellow'),
            fg('tea green'),
            fg('cornflower blue'),
            fg('turquoise blue'),
            fg('tropical blue'),
            fg('lavender purple'),
            fg('medium purple'),
        ]
        color = palette[int(base_uuid[-2:], 16) % len(palette)]
        shown_fname = 'nofname' if self.fname is None else self.fname
        shown_machine = 'nomachine' if self.machine is None else self.machine
        return f'{color}{tag}/{shown_fname}/{shown_machine}{reset()}'
268
269
class RemoteExecutorStatus:
    """Shared bookkeeping for a RemoteExecutor.

    Records which bundles are in flight on which worker, when each bundle
    started and ended, and the latencies of finished bundles (overall and per
    worker).  ``__repr__`` renders this as a status report and, as a side
    effect, updates each in-flight bundle's ``slower_than_local_p95`` /
    ``slower_than_global_p95`` flags.

    All state is guarded by ``self.lock``.  The ``*_already_locked`` methods,
    ``total_in_flight``, ``total_idle``, ``__repr__`` and ``periodic_dump``
    assert that the lock is held; the other ``record_*`` methods take it.
    """

    def __init__(self, total_worker_count: int) -> None:
        self.worker_count = total_worker_count
        self.known_workers: Set[RemoteWorkerRecord] = set()
        # None means a worker was acquired for the bundle but
        # record_processing_began() has not been called yet.
        self.start_per_bundle: Dict[str, Optional[float]] = defaultdict(float)
        self.end_per_bundle: Dict[str, float] = defaultdict(float)
        self.finished_bundle_timings_per_worker: Dict[
            RemoteWorkerRecord,
            List[float]
        ] = {}
        self.in_flight_bundles_by_worker: Dict[
            RemoteWorkerRecord,
            Set[str]
        ] = {}
        self.bundle_details_by_uuid: Dict[str, BundleDetails] = {}
        self.finished_bundle_timings: List[float] = []
        self.last_periodic_dump: Optional[float] = None
        self.total_bundles_submitted = 0

        # Protects reads and modification using self.  Also used
        # as a memory fence for modifications to bundle.
        self.lock = threading.Lock()

    def record_acquire_worker(
            self,
            worker: RemoteWorkerRecord,
            uuid: str
    ) -> None:
        """Locking wrapper around record_acquire_worker_already_locked."""
        with self.lock:
            self.record_acquire_worker_already_locked(
                worker,
                uuid
            )

    def record_acquire_worker_already_locked(
            self,
            worker: RemoteWorkerRecord,
            uuid: str
    ) -> None:
        """Mark bundle ``uuid`` as in flight on ``worker`` (not yet started)."""
        assert self.lock.locked()
        self.known_workers.add(worker)
        self.start_per_bundle[uuid] = None
        self.in_flight_bundles_by_worker.setdefault(worker, set()).add(uuid)

    def record_bundle_details(
            self,
            details: BundleDetails) -> None:
        """Locking wrapper around record_bundle_details_already_locked."""
        with self.lock:
            self.record_bundle_details_already_locked(details)

    def record_bundle_details_already_locked(
            self,
            details: BundleDetails) -> None:
        """Index ``details`` by its uuid for later status reporting."""
        assert self.lock.locked()
        self.bundle_details_by_uuid[details.uuid] = details

    def record_release_worker(
            self,
            worker: RemoteWorkerRecord,
            uuid: str,
            was_cancelled: bool,
    ) -> None:
        """Locking wrapper around record_release_worker_already_locked."""
        with self.lock:
            self.record_release_worker_already_locked(
                worker, uuid, was_cancelled
            )

    def record_release_worker_already_locked(
            self,
            worker: RemoteWorkerRecord,
            uuid: str,
            was_cancelled: bool,
    ) -> None:
        """Mark bundle ``uuid`` as no longer in flight on ``worker``.

        Always stamps the bundle's end time.  Unless ``was_cancelled``, the
        latency (end time minus the time processing began) is appended to the
        global and per-worker timing lists.  If processing never began there
        is no start time, so no latency is recorded.
        """
        assert self.lock.locked()
        ts = time.time()
        self.end_per_bundle[uuid] = ts
        self.in_flight_bundles_by_worker[worker].remove(uuid)
        if was_cancelled:
            return
        start_ts = self.start_per_bundle[uuid]
        if start_ts is None:
            # Subtracting None would raise TypeError; there is no meaningful
            # latency for a bundle that never started processing.
            return
        bundle_latency = ts - start_ts
        self.finished_bundle_timings_per_worker.setdefault(worker, []).append(
            bundle_latency
        )
        self.finished_bundle_timings.append(bundle_latency)

    def record_processing_began(self, uuid: str):
        """Stamp the time at which bundle ``uuid`` actually started running."""
        with self.lock:
            self.start_per_bundle[uuid] = time.time()

    def total_in_flight(self) -> int:
        """Number of bundles currently in flight across all known workers."""
        assert self.lock.locked()
        return sum(
            len(self.in_flight_bundles_by_worker[worker])
            for worker in self.known_workers
        )

    def total_idle(self) -> int:
        """Worker slots not currently running a bundle."""
        assert self.lock.locked()
        return self.worker_count - self.total_in_flight()

    def __repr__(self):
        assert self.lock.locked()
        ts = time.time()
        total_finished = len(self.finished_bundle_timings)
        total_in_flight = self.total_in_flight()
        ret = f'\n\n{underline()}Remote Executor Pool Status{reset()}: '
        qall = None
        if len(self.finished_bundle_timings) > 1:
            qall = numpy.quantile(self.finished_bundle_timings, [0.5, 0.95])
            ret += (
                f'⏱=∀p50:{qall[0]:.1f}s, ∀p95:{qall[1]:.1f}s, '
                f'✅={total_finished}/{self.total_bundles_submitted}, '
                f'💻n={total_in_flight}/{self.worker_count}\n'
            )
        else:
            ret += (
                f' ✅={total_finished}/{self.total_bundles_submitted}, '
                f'💻n={total_in_flight}/{self.worker_count}\n'
            )

        for worker in self.known_workers:
            ret += f'  {fg("lightning yellow")}{worker.machine}{reset()}: '
            timings = self.finished_bundle_timings_per_worker.get(worker, [])
            count = len(timings)
            qworker = None
            if count > 1:
                qworker = numpy.quantile(timings, [0.5, 0.95])
                ret += f' 💻p50: {qworker[0]:.1f}s, 💻p95: {qworker[1]:.1f}s\n'
            else:
                ret += '\n'
            if count > 0:
                ret += f'    ...finished {count} total bundle(s) so far\n'
            in_flight = len(self.in_flight_bundles_by_worker[worker])
            if in_flight > 0:
                ret += f'    ...{in_flight} bundles currently in flight:\n'
                for bundle_uuid in self.in_flight_bundles_by_worker[worker]:
                    details = self.bundle_details_by_uuid.get(
                        bundle_uuid,
                        None
                    )
                    pid = str(details.pid) if (details and details.pid != 0) else "TBD"
                    if self.start_per_bundle[bundle_uuid] is not None:
                        sec = ts - self.start_per_bundle[bundle_uuid]
                        ret += f'       (pid={pid}): {details} for {sec:.1f}s so far '
                    else:
                        ret += f'       {details} setting up / copying data...'
                        sec = 0.0

                    # Side effect: refresh the bundle's "slow" flags, which
                    # the backup scheduler reads.  details may be None if the
                    # bundle's details were never recorded.
                    if qworker is not None:
                        if sec > qworker[1]:
                            ret += f'{bg("red")}>💻p95{reset()} '
                            if details is not None:
                                details.slower_than_local_p95 = True
                        elif details is not None:
                            details.slower_than_local_p95 = False

                    if qall is not None:
                        if sec > qall[1]:
                            ret += f'{bg("red")}>∀p95{reset()} '
                            if details is not None:
                                details.slower_than_global_p95 = True
                        elif details is not None:
                            details.slower_than_global_p95 = False
                    ret += '\n'
        return ret

    def periodic_dump(self, total_bundles_submitted: int) -> None:
        """Record the submission total and print status at most every 5s."""
        assert self.lock.locked()
        self.total_bundles_submitted = total_bundles_submitted
        ts = time.time()
        if (
                self.last_periodic_dump is None
                or ts - self.last_periodic_dump > 5.0
        ):
            print(self)
            self.last_periodic_dump = ts
448
449
class RemoteWorkerSelectionPolicy(ABC):
    """Strategy interface for choosing which remote worker runs a bundle."""

    def register_worker_pool(self, workers):
        """Remember the worker records this policy chooses among."""
        self.workers = workers

    @abstractmethod
    def is_worker_available(self) -> bool:
        """Return True if some worker could be acquired right now."""

    @abstractmethod
    def acquire_worker(
            self,
            machine_to_avoid = None
    ) -> Optional[RemoteWorkerRecord]:
        """Claim a worker (ideally not on ``machine_to_avoid``) or return None."""
464
465
class WeightedRandomRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
    """Choose a free worker at random, weighted by free slots times weight.

    Each worker appears in a "grab bag" once per free slot per unit of
    weight, so a draw from the bag favors machines with more free capacity
    and higher weight.
    """

    @overrides
    def is_worker_available(self) -> bool:
        return any(worker.count > 0 for worker in self.workers)

    @overrides
    def acquire_worker(
            self,
            machine_to_avoid = None
    ) -> Optional[RemoteWorkerRecord]:
        """Draw a worker, decrement its free count and return it.

        The first three draws reject ``machine_to_avoid``; the remaining two
        accept any worker so a single-machine pool cannot starve the caller.
        Returns None (after logging a warning) if nothing could be chosen,
        including when the grab bag is empty.
        """
        grabbag = [
            worker
            for worker in self.workers
            for _ in range(worker.count)
            for _ in range(worker.weight)
        ]

        # An empty bag (no free slots, or only zero-weight free workers)
        # used to raise IndexError; report "no worker" instead.
        if grabbag:
            for attempt in range(5):
                worker = random.choice(grabbag)
                if worker.machine != machine_to_avoid or attempt > 2:
                    if worker.count > 0:
                        worker.count -= 1
                        logger.debug(f'Selected worker {worker}')
                        return worker
        msg = 'Unexpectedly could not find a worker, retrying...'
        logger.warning(msg)
        return None
496
497
class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
    """Cycle through workers in order, skipping those with no free slots.

    ``machine_to_avoid`` is accepted for interface compatibility but ignored.
    """

    def __init__(self) -> None:
        self.index = 0  # position where the next search starts

    @overrides
    def is_worker_available(self) -> bool:
        return any(worker.count > 0 for worker in self.workers)

    @overrides
    def acquire_worker(
            self,
            machine_to_avoid: Optional[str] = None
    ) -> Optional[RemoteWorkerRecord]:
        """Return the next worker (from ``self.index``) with a free slot.

        Decrements that worker's count and advances the cursor past it.
        Returns None (after logging a warning) when no worker has a free slot,
        including when the worker list is empty, which previously raised
        IndexError.
        """
        n = len(self.workers)
        if n > 0:
            # Modulo guards against a stale cursor if the list ever shrank.
            first = self.index % n
            for offset in range(n):
                position = (first + offset) % n
                worker = self.workers[position]
                if worker.count > 0:
                    worker.count -= 1
                    self.index = (position + 1) % n
                    logger.debug(f'Selected worker {worker}')
                    return worker
        msg = 'Unexpectedly could not find a worker, retrying...'
        logger.warning(msg)
        return None
532
533
534 class RemoteExecutor(BaseExecutor):
535     def __init__(self,
536                  workers: List[RemoteWorkerRecord],
537                  policy: RemoteWorkerSelectionPolicy) -> None:
538         super().__init__()
539         self.workers = workers
540         self.policy = policy
541         self.worker_count = 0
542         for worker in self.workers:
543             self.worker_count += worker.count
544         if self.worker_count <= 0:
545             msg = f"We need somewhere to schedule work; count was {self.worker_count}"
546             logger.critical(msg)
547             raise RemoteExecutorException(msg)
548         self.policy.register_worker_pool(self.workers)
549         self.cv = threading.Condition()
550         logger.debug(f'Creating {self.worker_count} local threads, one per remote worker.')
551         self._helper_executor = fut.ThreadPoolExecutor(
552             thread_name_prefix="remote_executor_helper",
553             max_workers=self.worker_count,
554         )
555         self.status = RemoteExecutorStatus(self.worker_count)
556         self.total_bundles_submitted = 0
557         self.backup_lock = threading.Lock()
558         self.last_backup = None
559
560     def is_worker_available(self) -> bool:
561         return self.policy.is_worker_available()
562
563     def acquire_worker(
564             self,
565             machine_to_avoid: str = None
566     ) -> Optional[RemoteWorkerRecord]:
567         return self.policy.acquire_worker(machine_to_avoid)
568
569     def find_available_worker_or_block(
570             self,
571             machine_to_avoid: str = None
572     ) -> RemoteWorkerRecord:
573         with self.cv:
574             while not self.is_worker_available():
575                 self.cv.wait()
576             worker = self.acquire_worker(machine_to_avoid)
577             if worker is not None:
578                 return worker
579         msg = "We should never reach this point in the code"
580         logger.critical(msg)
581         raise Exception(msg)
582
583     def release_worker(self, worker: RemoteWorkerRecord) -> None:
584         logger.debug(f'Released worker {worker}')
585         with self.cv:
586             worker.count += 1
587             self.cv.notify()
588
589     def heartbeat(self) -> None:
590         with self.status.lock:
591             # Regular progress report
592             self.status.periodic_dump(self.total_bundles_submitted)
593
594             # Look for bundles to reschedule.
595             num_done = len(self.status.finished_bundle_timings)
596             num_idle_workers = self.worker_count - self.task_count
597             now = time.time()
598             if (
599                     config.config['executors_schedule_remote_backups']
600                     and num_done > 2
601                     and num_idle_workers > 1
602                     and (self.last_backup is None or (now - self.last_backup > 1.0))
603                     and self.backup_lock.acquire(blocking=False)
604             ):
605                 try:
606                     assert self.backup_lock.locked()
607
608                     bundle_to_backup = None
609                     best_score = None
610                     for worker, bundle_uuids in self.status.in_flight_bundles_by_worker.items():
611                         # Prefer to schedule backups of bundles on slower machines.
612                         base_score = 0
613                         for record in self.workers:
614                             if worker.machine == record.machine:
615                                 base_score = float(record.weight)
616                                 base_score = 1.0 / base_score
617                                 base_score *= 200.0
618                                 base_score = int(base_score)
619                                 break
620
621                         for uuid in bundle_uuids:
622                             bundle = self.status.bundle_details_by_uuid.get(uuid, None)
623                             if (
624                                     bundle is not None
625                                     and bundle.src_bundle is None
626                                     and bundle.backup_bundles is not None
627                             ):
628                                 score = base_score
629
630                                 # Schedule backups of bundles running longer; especially those
631                                 # that are unexpectedly slow.
632                                 start_ts = self.status.start_per_bundle[uuid]
633                                 if start_ts is not None:
634                                     runtime = now - start_ts
635                                     score += runtime
636                                     logger.debug(f'score[{bundle}] => {score}  # latency boost')
637
638                                     if bundle.slower_than_local_p95:
639                                         score += runtime / 2
640                                         logger.debug(f'score[{bundle}] => {score}  # >worker p95')
641
642                                     if bundle.slower_than_global_p95:
643                                         score += runtime / 2
644                                         logger.debug(f'score[{bundle}] => {score}  # >global p95')
645
646                                 # Prefer backups of bundles that don't have backups already.
647                                 backup_count = len(bundle.backup_bundles)
648                                 if backup_count == 0:
649                                     score *= 2
650                                 elif backup_count == 1:
651                                     score /= 2
652                                 elif backup_count == 2:
653                                     score /= 8
654                                 else:
655                                     score = 0
656                                 logger.debug(f'score[{bundle}] => {score}  # {backup_count} dup backup factor')
657
658                                 if (
659                                         score != 0
660                                         and (best_score is None or score > best_score)
661                                 ):
662                                     bundle_to_backup = bundle
663                                     assert bundle is not None
664                                     assert bundle.backup_bundles is not None
665                                     assert bundle.src_bundle is None
666                                     best_score = score
667
668                     if bundle_to_backup is not None:
669                         self.last_backup = now
670                         logger.info(f'=====> SCHEDULING BACKUP {bundle_to_backup} (score={best_score:.1f}) <=====')
671                         self.schedule_backup_for_bundle(bundle_to_backup)
672                 finally:
673                     self.backup_lock.release()
674
675     def check_if_cancelled(self, bundle: BundleDetails) -> bool:
676         with self.status.lock:
677             if bundle.is_cancelled.wait(timeout=0.0):
678                 logger.debug(f'Bundle {bundle.uuid} is cancelled, bail out.')
679                 bundle.was_cancelled = True
680                 return True
681         return False
682
683     def launch(self, bundle: BundleDetails, override_avoid_machine=None) -> Any:
684         """Find a worker for bundle or block until one is available."""
685         self.adjust_task_count(+1)
686         uuid = bundle.uuid
687         hostname = bundle.hostname
688         avoid_machine = override_avoid_machine
689         is_original = bundle.src_bundle is None
690
691         # Try not to schedule a backup on the same host as the original.
692         if avoid_machine is None and bundle.src_bundle is not None:
693             avoid_machine = bundle.src_bundle.machine
694         worker = None
695         while worker is None:
696             worker = self.find_available_worker_or_block(avoid_machine)
697
698         # Ok, found a worker.
699         bundle.worker = worker
700         machine = bundle.machine = worker.machine
701         username = bundle.username = worker.username
702
703         self.status.record_acquire_worker(worker, uuid)
704         logger.debug(f'{bundle}: Running bundle on {worker}...')
705
706         # Before we do any work, make sure the bundle is still viable.
707         if self.check_if_cancelled(bundle):
708             try:
709                 return self.post_launch_work(bundle)
710             except Exception as e:
711                 logger.exception(e)
712                 logger.error(
713                     f'{bundle}: bundle says it\'s cancelled upfront but no results?!'
714                 )
715                 assert bundle.worker is not None
716                 self.status.record_release_worker(
717                     bundle.worker,
718                     bundle.uuid,
719                     True,
720                 )
721                 self.release_worker(bundle.worker)
722                 self.adjust_task_count(-1)
723                 if is_original:
724                     # Weird.  We are the original owner of this
725                     # bundle.  For it to have been cancelled, a backup
726                     # must have already started and completed before
727                     # we even for started.  Moreover, the backup says
728                     # it is done but we can't find the results it
729                     # should have copied over.  Reschedule the whole
730                     # thing.
731                     return self.emergency_retry_nasty_bundle(bundle)
732                 else:
733                     # Expected(?).  We're a backup and our bundle is
734                     # cancelled before we even got started.  Something
735                     # went bad in post_launch_work (I acutually don't
736                     # see what?) but probably not worth worrying
737                     # about.
738                     return None
739
740         # Send input code / data to worker machine if it's not local.
741         if hostname not in machine:
742             try:
743                 cmd = f'{SCP} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
744                 start_ts = time.time()
745                 logger.info(f"{bundle}: Copying work to {worker} via {cmd}.")
746                 run_silently(cmd)
747                 xfer_latency = time.time() - start_ts
748                 logger.info(f"{bundle}: Copying done to {worker} in {xfer_latency:.1f}s.")
749             except Exception as e:
750                 assert bundle.worker is not None
751                 self.status.record_release_worker(
752                     bundle.worker,
753                     bundle.uuid,
754                     True,
755                 )
756                 self.release_worker(bundle.worker)
757                 self.adjust_task_count(-1)
758                 if is_original:
759                     # Weird.  We tried to copy the code to the worker and it failed...
760                     # And we're the original bundle.  We have to retry.
761                     logger.exception(e)
762                     logger.error(
763                         f'{bundle}: Failed to send instructions to the worker machine?! ' +
764                         'This is not expected; we\'re the original bundle so this shouldn\'t ' +
765                         'be a race condition.  Attempting an emergency retry...'
766                     )
767                     return self.emergency_retry_nasty_bundle(bundle)
768                 else:
769                     # This is actually expected; we're a backup.
770                     # There's a race condition where someone else
771                     # already finished the work and removed the source
772                     # code file before we could copy it.  No biggie.
773                     msg = f'{bundle}: Failed to send instructions to the worker machine... '
774                     msg += 'We\'re a backup and this may be caused by the original (or some '
775                     msg += 'other backup) already finishing this work.  Ignoring this.'
776                     logger.warning(msg)
777                     return None
778
779         # Kick off the work.  Note that if this fails we let
780         # wait_for_process deal with it.
781         self.status.record_processing_began(uuid)
782         cmd = (f'{SSH} {bundle.username}@{bundle.machine} '
783                f'"source py38-venv/bin/activate &&'
784                f' /home/scott/lib/python_modules/remote_worker.py'
785                f' --code_file {bundle.code_file} --result_file {bundle.result_file}"')
786         logger.debug(f'{bundle}: Executing {cmd} in the background to kick off work...')
787         p = cmd_in_background(cmd, silent=True)
788         bundle.pid = pid = p.pid
789         logger.debug(f'{bundle}: Local ssh process pid={pid}; remote worker is {machine}.')
790         return self.wait_for_process(p, bundle, 0)
791
792     def wait_for_process(self, p: subprocess.Popen, bundle: BundleDetails, depth: int) -> Any:
793         machine = bundle.machine
794         pid = p.pid
795         if depth > 3:
796             logger.error(
797                 f"I've gotten repeated errors waiting on this bundle; giving up on pid={pid}."
798             )
799             p.terminate()
800             self.status.record_release_worker(
801                 bundle.worker,
802                 bundle.uuid,
803                 True,
804             )
805             self.release_worker(bundle.worker)
806             self.adjust_task_count(-1)
807             return self.emergency_retry_nasty_bundle(bundle)
808
809         # Spin until either the ssh job we scheduled finishes the
810         # bundle or some backup worker signals that they finished it
811         # before we could.
812         while True:
813             try:
814                 p.wait(timeout=0.25)
815             except subprocess.TimeoutExpired:
816                 self.heartbeat()
817                 if self.check_if_cancelled(bundle):
818                     logger.info(
819                         f'{bundle}: another worker finished bundle, checking it out...'
820                     )
821                     break
822             else:
823                 logger.info(
824                     f"{bundle}: pid {pid} ({machine}) our ssh finished, checking it out..."
825                 )
826                 p = None
827                 break
828
829         # If we get here we believe the bundle is done; either the ssh
830         # subprocess finished (hopefully successfully) or we noticed
831         # that some other worker seems to have completed the bundle
832         # and we're bailing out.
833         try:
834             ret = self.post_launch_work(bundle)
835             if ret is not None and p is not None:
836                 p.terminate()
837             return ret
838
839         # Something went wrong; e.g. we could not copy the results
840         # back, cleanup after ourselves on the remote machine, or
841         # unpickle the results we got from the remove machine.  If we
842         # still have an active ssh subprocess, keep waiting on it.
843         # Otherwise, time for an emergency reschedule.
844         except Exception as e:
845             logger.exception(e)
846             logger.error(f'{bundle}: Something unexpected just happened...')
847             if p is not None:
848                 msg = f"{bundle}: Failed to wrap up \"done\" bundle, re-waiting on active ssh."
849                 logger.warning(msg)
850                 return self.wait_for_process(p, bundle, depth + 1)
851             else:
852                 self.status.record_release_worker(
853                     bundle.worker,
854                     bundle.uuid,
855                     True,
856                 )
857                 self.release_worker(bundle.worker)
858                 self.adjust_task_count(-1)
859                 return self.emergency_retry_nasty_bundle(bundle)
860
861     def post_launch_work(self, bundle: BundleDetails) -> Any:
862         with self.status.lock:
863             is_original = bundle.src_bundle is None
864             was_cancelled = bundle.was_cancelled
865             username = bundle.username
866             machine = bundle.machine
867             result_file = bundle.result_file
868             code_file = bundle.code_file
869
870             # Whether original or backup, if we finished first we must
871             # fetch the results if the computation happened on a
872             # remote machine.
873             bundle.end_ts = time.time()
874             if not was_cancelled:
875                 assert bundle.machine is not None
876                 if bundle.hostname not in bundle.machine:
877                     cmd = f'{SCP} {username}@{machine}:{result_file} {result_file} 2>/dev/null'
878                     logger.info(
879                         f"{bundle}: Fetching results from {username}@{machine} via {cmd}"
880                     )
881
882                     # If either of these throw they are handled in
883                     # wait_for_process.
884                     run_silently(cmd)
885                     run_silently(f'{SSH} {username}@{machine}'
886                                  f' "/bin/rm -f {code_file} {result_file}"')
887                 dur = bundle.end_ts - bundle.start_ts
888                 self.histogram.add_item(dur)
889
890         # Only the original worker should unpickle the file contents
891         # though since it's the only one whose result matters.  The
892         # original is also the only job that may delete result_file
893         # from disk.  Note that the original may have been cancelled
894         # if one of the backups finished first; it still must read the
895         # result from disk.
896         if is_original:
897             logger.debug(f"{bundle}: Unpickling {result_file}.")
898             try:
899                 with open(result_file, 'rb') as rb:
900                     serialized = rb.read()
901                 result = cloudpickle.loads(serialized)
902             except Exception as e:
903                 msg = f'Failed to load {result_file}, this is bad news.'
904                 logger.critical(msg)
905                 self.status.record_release_worker(
906                     bundle.worker,
907                     bundle.uuid,
908                     True,
909                 )
910                 self.release_worker(bundle.worker)
911
912                 # Re-raise the exception; the code in wait_for_process may
913                 # decide to emergency_retry_nasty_bundle here.
914                 raise Exception(e)
915
916             logger.debug(
917                 f'Removing local (master) {code_file} and {result_file}.'
918             )
919             os.remove(f'{result_file}')
920             os.remove(f'{code_file}')
921
922             # Notify any backups that the original is done so they
923             # should stop ASAP.  Do this whether or not we
924             # finished first since there could be more than one
925             # backup.
926             if bundle.backup_bundles is not None:
927                 for backup in bundle.backup_bundles:
928                     logger.debug(
929                         f'{bundle}: Notifying backup {backup.uuid} that it\'s cancelled'
930                     )
931                     backup.is_cancelled.set()
932
933         # This is a backup job and, by now, we have already fetched
934         # the bundle results.
935         else:
936             # Backup results don't matter, they just need to leave the
937             # result file in the right place for their originals to
938             # read/unpickle later.
939             result = None
940
941             # Tell the original to stop if we finished first.
942             if not was_cancelled:
943                 logger.debug(
944                     f'{bundle}: Notifying original {bundle.src_bundle.uuid} we beat them to it.'
945                 )
946                 bundle.src_bundle.is_cancelled.set()
947
948         assert bundle.worker is not None
949         self.status.record_release_worker(
950             bundle.worker,
951             bundle.uuid,
952             was_cancelled,
953         )
954         self.release_worker(bundle.worker)
955         self.adjust_task_count(-1)
956         return result
957
958     def create_original_bundle(self, pickle, fname: str):
959         from string_utils import generate_uuid
960         uuid = generate_uuid(omit_dashes=True)
961         code_file = f'/tmp/{uuid}.code.bin'
962         result_file = f'/tmp/{uuid}.result.bin'
963
964         logger.debug(f'Writing pickled code to {code_file}')
965         with open(f'{code_file}', 'wb') as wb:
966             wb.write(pickle)
967
968         bundle = BundleDetails(
969             pickled_code = pickle,
970             uuid = uuid,
971             fname = fname,
972             worker = None,
973             username = None,
974             machine = None,
975             hostname = platform.node(),
976             code_file = code_file,
977             result_file = result_file,
978             pid = 0,
979             start_ts = time.time(),
980             end_ts = 0.0,
981             slower_than_local_p95 = False,
982             slower_than_global_p95 = False,
983             src_bundle = None,
984             is_cancelled = threading.Event(),
985             was_cancelled = False,
986             backup_bundles = [],
987             failure_count = 0,
988         )
989         self.status.record_bundle_details(bundle)
990         logger.debug(f'{bundle}: Created an original bundle')
991         return bundle
992
993     def create_backup_bundle(self, src_bundle: BundleDetails):
994         assert src_bundle.backup_bundles is not None
995         n = len(src_bundle.backup_bundles)
996         uuid = src_bundle.uuid + f'_backup#{n}'
997
998         backup_bundle = BundleDetails(
999             pickled_code = src_bundle.pickled_code,
1000             uuid = uuid,
1001             fname = src_bundle.fname,
1002             worker = None,
1003             username = None,
1004             machine = None,
1005             hostname = src_bundle.hostname,
1006             code_file = src_bundle.code_file,
1007             result_file = src_bundle.result_file,
1008             pid = 0,
1009             start_ts = time.time(),
1010             end_ts = 0.0,
1011             slower_than_local_p95 = False,
1012             slower_than_global_p95 = False,
1013             src_bundle = src_bundle,
1014             is_cancelled = threading.Event(),
1015             was_cancelled = False,
1016             backup_bundles = None,    # backup backups not allowed
1017             failure_count = 0,
1018         )
1019         src_bundle.backup_bundles.append(backup_bundle)
1020         self.status.record_bundle_details_already_locked(backup_bundle)
1021         logger.debug(f'{backup_bundle}: Created a backup bundle')
1022         return backup_bundle
1023
1024     def schedule_backup_for_bundle(self,
1025                                    src_bundle: BundleDetails):
1026         assert self.status.lock.locked()
1027         assert src_bundle is not None
1028         backup_bundle = self.create_backup_bundle(src_bundle)
1029         logger.debug(
1030             f'{backup_bundle.uuid}/{backup_bundle.fname}: Scheduling backup for execution...'
1031         )
1032         self._helper_executor.submit(self.launch, backup_bundle)
1033
1034         # Results from backups don't matter; if they finish first
1035         # they will move the result_file to this machine and let
1036         # the original pick them up and unpickle them.
1037
1038     def emergency_retry_nasty_bundle(self, bundle: BundleDetails) -> fut.Future:
1039         is_original = bundle.src_bundle is None
1040         bundle.worker = None
1041         avoid_last_machine = bundle.machine
1042         bundle.machine = None
1043         bundle.username = None
1044         bundle.failure_count += 1
1045         if is_original:
1046             retry_limit = 3
1047         else:
1048             retry_limit = 2
1049
1050         if bundle.failure_count > retry_limit:
1051             logger.error(
1052                 f'{bundle}: Tried this bundle too many times already ({retry_limit}x); giving up.'
1053             )
1054             if is_original:
1055                 raise RemoteExecutorException(
1056                     f'{bundle}: This bundle can\'t be completed despite several backups and retries'
1057                 )
1058             else:
1059                 logger.error(f'{bundle}: At least it\'s only a backup; better luck with the others.')
1060             return None
1061         else:
1062             msg = f'>>> Emergency rescheduling {bundle} because of unexected errors (wtf?!) <<<'
1063             logger.warning(msg)
1064             warnings.warn(msg)
1065             return self.launch(bundle, avoid_last_machine)
1066
1067     @overrides
1068     def submit(self,
1069                function: Callable,
1070                *args,
1071                **kwargs) -> fut.Future:
1072         pickle = make_cloud_pickle(function, *args, **kwargs)
1073         bundle = self.create_original_bundle(pickle, function.__name__)
1074         self.total_bundles_submitted += 1
1075         return self._helper_executor.submit(self.launch, bundle)
1076
1077     @overrides
1078     def shutdown(self, wait=True) -> None:
1079         self._helper_executor.shutdown(wait)
1080         logging.debug(f'Shutting down RemoteExecutor {self.title}')
1081         print(self.histogram)
1082
1083
@singleton
class DefaultExecutors(object):
    """Holder of lazily-created default executors.

    Each pool (thread, process, remote) is constructed on the first
    request for it; later requests return the same instance.
    """

    # Remote worker candidates probed, in this order, by remote_pool().
    # Each entry is (machine, username, weight, count).  kiosk.house
    # (username 'pi', weight 1, count 2) is deliberately left out.
    _REMOTE_WORKER_CANDIDATES = (
        ('cheetah.house', 'scott', 25, 6),
        ('meerkat.cabin', 'scott', 5, 2),
        ('hero.house', 'scott', 30, 10),
        ('puma.cabin', 'scott', 25, 6),
        ('backup.house', 'scott', 3, 2),
    )

    def __init__(self):
        self.thread_executor: Optional[ThreadExecutor] = None
        self.process_executor: Optional[ProcessExecutor] = None
        self.remote_executor: Optional[RemoteExecutor] = None

    def ping(self, host) -> bool:
        """Return True iff ``host`` answers a single ping within ~1 second.

        Any exception raised while running the command is treated as
        "not reachable".
        """
        logger.debug(f'RUN> ping -c 1 {host}')
        try:
            x = cmd_with_timeout(
                f'ping -c 1 {host} >/dev/null 2>/dev/null',
                timeout_seconds=1.0
            )
            return x == 0
        except Exception:
            return False

    def thread_pool(self) -> ThreadExecutor:
        """Return the default ThreadExecutor, creating it on first use."""
        if self.thread_executor is None:
            self.thread_executor = ThreadExecutor()
        return self.thread_executor

    def process_pool(self) -> ProcessExecutor:
        """Return the default ProcessExecutor, creating it on first use."""
        if self.process_executor is None:
            self.process_executor = ProcessExecutor()
        return self.process_executor

    @staticmethod
    def _is_controller(machine: str) -> bool:
        """Return True if ``machine`` names the host this code runs on.

        Only the first hostname label is compared, so a short
        platform.node() such as 'hero' matches the worker 'hero.house'
        (and a fully-qualified node name still matches too).
        """
        node = platform.node()
        if not node:
            return False
        return machine.split('.')[0] == node.split('.')[0]

    def remote_pool(self) -> RemoteExecutor:
        """Return the default RemoteExecutor, creating it on first use.

        Pings each candidate worker machine, pools the reachable ones,
        limits this (controller) machine to a single worker slot and
        registers the pool with a WeightedRandomRemoteWorkerSelectionPolicy.
        """
        if self.remote_executor is None:
            logger.info('Looking for some helper machines...')
            pool: List[RemoteWorkerRecord] = []
            for machine, username, weight, count in self._REMOTE_WORKER_CANDIDATES:
                if self.ping(machine):
                    logger.info(f'Found {machine}')
                    pool.append(
                        RemoteWorkerRecord(
                            username=username,
                            machine=machine,
                            weight=weight,
                            count=count,
                        ),
                    )

            # The controller machine has a lot to do; go easy on it.
            for record in pool:
                if self._is_controller(record.machine) and record.count > 1:
                    logger.info(f'Reducing workload for {record.machine}.')
                    record.count = 1

            policy = WeightedRandomRemoteWorkerSelectionPolicy()
            policy.register_worker_pool(pool)
            self.remote_executor = RemoteExecutor(pool, policy)
        return self.remote_executor