Fix type hints in executors.
authorScott Gasch <[email protected]>
Tue, 9 May 2023 21:09:37 +0000 (14:09 -0700)
committerScott Gasch <[email protected]>
Tue, 9 May 2023 21:09:37 +0000 (14:09 -0700)
src/pyutils/parallelize/executors.py

index fd8cc7cdc31ce221059e66bcc68f53873135d4cb..04d6a80ecdc6a78e46c487eaa9b43891f5507ab3 100644 (file)
@@ -678,7 +678,9 @@ class RemoteWorkerSelectionPolicy(ABC):
         pass
 
     @abstractmethod
-    def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
+    def acquire_worker(
+        self, machine_to_avoid: str = None
+    ) -> Optional[RemoteWorkerRecord]:
         pass
 
 
@@ -694,7 +696,9 @@ class WeightedRandomRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
         return False
 
     @overrides
-    def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
+    def acquire_worker(
+        self, machine_to_avoid: str = None
+    ) -> Optional[RemoteWorkerRecord]:
         grabbag = []
         if self.workers:
             for worker in self.workers:
@@ -739,6 +743,13 @@ class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
                     return True
         return False
 
+    def _increment_index(self, index: int) -> None:
+        if self.workers:
+            index += 1
+            if index >= len(self.workers):
+                index = 0
+            self.index = index
+
     @overrides
     def acquire_worker(
         self, machine_to_avoid: str = None
@@ -747,12 +758,9 @@ class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
             x = self.index
             while True:
                 worker = self.workers[x]
-                if worker.count > 0:
+                if worker.machine != machine_to_avoid and worker.count > 0:
                     worker.count -= 1
-                    x += 1
-                    if x >= len(self.workers):
-                        x = 0
-                    self.index = x
+                    self._increment_index(x)
                     logger.debug('Selected worker %s', worker)
                     return worker
                 x += 1