X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=ml%2Fmodel_trainer.py;h=12ccb3c6c0508e61081d6791b12f9a8f5a5571f2;hb=713a609bd19d491de03debf8a4a6ddf2540b13dc;hp=a37885ce3e92f20d11bb61f4268407689829bca4;hpb=7ff2af6fe7bffea90dc4a31c93140c189917c659;p=python_utils.git diff --git a/ml/model_trainer.py b/ml/model_trainer.py index a37885c..12ccb3c 100644 --- a/ml/model_trainer.py +++ b/ml/model_trainer.py @@ -28,7 +28,8 @@ from decorator_utils import timed logger = logging.getLogger(__file__) parser = config.add_commandline_args( - f"ML Model Trainer ({__file__})", "Arguments related to training an ML model" + f"ML Model Trainer ({__file__})", + "Arguments related to training an ML model", ) parser.add_argument( "--ml_trainer_quiet", @@ -217,17 +218,12 @@ class TrainingBlueprint(ABC): try: (key, value) = line.split(self.spec.key_value_delimiter) except Exception: - logger.debug( - f"WARNING: bad line in file {filename} '{line}', skipped" - ) + logger.debug(f"WARNING: bad line in file {filename} '{line}', skipped") continue key = key.strip() value = value.strip() - if ( - self.spec.features_to_skip is not None - and key in self.spec.features_to_skip - ): + if self.spec.features_to_skip is not None and key in self.spec.features_to_skip: logger.debug(f"Skipping feature {key}") continue @@ -321,9 +317,7 @@ class TrainingBlueprint(ABC): # Note: children should implement. Consider using @parallelize. @abstractmethod - def train_model( - self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray - ) -> Any: + def train_model(self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray) -> Any: pass def evaluate_model( @@ -368,8 +362,7 @@ Testing set score: {test_score:.2f}%""" self.spec.persist_percentage_threshold is not None and test_score > self.spec.persist_percentage_threshold ) or ( - not self.spec.quiet - and input_utils.yn_response("Write the model? [y,n]: ") == "y" + not self.spec.quiet and input_utils.yn_response("Write the model? [y,n]: ") == "y" ): scaler_filename = f"{self.spec.basename}_scaler.sav" with open(scaler_filename, "wb") as fb: