From: Scott Gasch Date: Tue, 9 May 2023 21:09:37 +0000 (-0700) Subject: Fix type hints in executors. X-Git-Url: https://wannabe.guru.org/gitweb/?a=commitdiff_plain;h=c94f51d840423854dd87aa99e1634fdc5a770a88;p=pyutils.git Fix type hints in executors. --- diff --git a/src/pyutils/parallelize/executors.py b/src/pyutils/parallelize/executors.py index fd8cc7c..04d6a80 100644 --- a/src/pyutils/parallelize/executors.py +++ b/src/pyutils/parallelize/executors.py @@ -678,7 +678,9 @@ class RemoteWorkerSelectionPolicy(ABC): pass @abstractmethod - def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]: + def acquire_worker( + self, machine_to_avoid: str = None + ) -> Optional[RemoteWorkerRecord]: pass @@ -694,7 +696,9 @@ class WeightedRandomRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy): return False @overrides - def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]: + def acquire_worker( + self, machine_to_avoid: str = None + ) -> Optional[RemoteWorkerRecord]: grabbag = [] if self.workers: for worker in self.workers: @@ -739,6 +743,13 @@ class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy): return True return False + def _increment_index(self, index: int) -> None: + if self.workers: + index += 1 + if index >= len(self.workers): + index = 0 + self.index = index + @overrides def acquire_worker( self, machine_to_avoid: str = None @@ -747,12 +758,9 @@ class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy): x = self.index while True: worker = self.workers[x] - if worker.count > 0: + if worker.machine != machine_to_avoid and worker.count > 0: worker.count -= 1 - x += 1 - if x >= len(self.workers): - x = 0 - self.index = x + self._increment_index(x) logger.debug('Selected worker %s', worker) return worker x += 1