Change settings in flake8 and black.
[python_utils.git] / ml / model_trainer.py
index a37885ce3e92f20d11bb61f4268407689829bca4..12ccb3c6c0508e61081d6791b12f9a8f5a5571f2 100644 (file)
@@ -28,7 +28,8 @@ from decorator_utils import timed
 logger = logging.getLogger(__file__)
 
 parser = config.add_commandline_args(
-    f"ML Model Trainer ({__file__})", "Arguments related to training an ML model"
+    f"ML Model Trainer ({__file__})",
+    "Arguments related to training an ML model",
 )
 parser.add_argument(
     "--ml_trainer_quiet",
@@ -217,17 +218,12 @@ class TrainingBlueprint(ABC):
                 try:
                     (key, value) = line.split(self.spec.key_value_delimiter)
                 except Exception:
-                    logger.debug(
-                        f"WARNING: bad line in file {filename} '{line}', skipped"
-                    )
+                    logger.debug(f"WARNING: bad line in file {filename} '{line}', skipped")
                     continue
 
                 key = key.strip()
                 value = value.strip()
-                if (
-                    self.spec.features_to_skip is not None
-                    and key in self.spec.features_to_skip
-                ):
+                if self.spec.features_to_skip is not None and key in self.spec.features_to_skip:
                     logger.debug(f"Skipping feature {key}")
                     continue
 
@@ -321,9 +317,7 @@ class TrainingBlueprint(ABC):
 
     # Note: children should implement.  Consider using @parallelize.
     @abstractmethod
-    def train_model(
-        self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray
-    ) -> Any:
+    def train_model(self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray) -> Any:
         pass
 
     def evaluate_model(
@@ -368,8 +362,7 @@ Testing set score: {test_score:.2f}%"""
                 self.spec.persist_percentage_threshold is not None
                 and test_score > self.spec.persist_percentage_threshold
             ) or (
-                not self.spec.quiet
-                and input_utils.yn_response("Write the model? [y,n]: ") == "y"
+                not self.spec.quiet and input_utils.yn_response("Write the model? [y,n]: ") == "y"
             ):
                 scaler_filename = f"{self.spec.basename}_scaler.sav"
                 with open(scaler_filename, "wb") as fb: