Change settings in flake8 and black.
[python_utils.git] / ml / model_trainer.py
index 213a1814cff5e98507e30c19e17669ab123886ce..12ccb3c6c0508e61081d6791b12f9a8f5a5571f2 100644 (file)
@@ -2,7 +2,6 @@
 
 from __future__ import annotations
 
-from abc import ABC, abstractmethod
 import datetime
 import glob
 import logging
@@ -10,25 +9,27 @@ import os
 import pickle
 import random
 import sys
+import warnings
+from abc import ABC, abstractmethod
 from types import SimpleNamespace
 from typing import Any, List, NamedTuple, Optional, Set, Tuple
-import warnings
 
 import numpy as np
 from sklearn.model_selection import train_test_split  # type:ignore
 from sklearn.preprocessing import MinMaxScaler  # type: ignore
 
-from ansi import bold, reset
 import argparse_utils
 import config
-from decorator_utils import timed
 import executors
 import parallelize as par
+from ansi import bold, reset
+from decorator_utils import timed
 
 logger = logging.getLogger(__file__)
 
 parser = config.add_commandline_args(
-    f"ML Model Trainer ({__file__})", "Arguments related to training an ML model"
+    f"ML Model Trainer ({__file__})",
+    "Arguments related to training an ML model",
 )
 parser.add_argument(
     "--ml_trainer_quiet",
@@ -81,8 +82,8 @@ class OutputSpec(NamedTuple):
     model_filename: Optional[str]
     model_info_filename: Optional[str]
     scaler_filename: Optional[str]
-    training_score: float
-    test_score: float
+    training_score: np.float64
+    test_score: np.float64
 
 
 class TrainingBlueprint(ABC):
@@ -131,9 +132,9 @@ class TrainingBlueprint(ABC):
             modelid_to_params[model.get_id()] = str(params)
 
         best_model = None
-        best_score = None
-        best_test_score = None
-        best_training_score = None
+        best_score: Optional[np.float64] = None
+        best_test_score: Optional[np.float64] = None
+        best_training_score: Optional[np.float64] = None
         best_params = None
         for model in smart_future.wait_any(models):
             params = modelid_to_params[model.get_id()]
@@ -170,6 +171,9 @@ class TrainingBlueprint(ABC):
             print(msg)
             logger.info(msg)
 
+        assert best_training_score is not None
+        assert best_test_score is not None
+        assert best_params is not None
         (
             scaler_filename,
             model_filename,
@@ -214,17 +218,12 @@ class TrainingBlueprint(ABC):
                 try:
                     (key, value) = line.split(self.spec.key_value_delimiter)
                 except Exception:
-                    logger.debug(
-                        f"WARNING: bad line in file {filename} '{line}', skipped"
-                    )
+                    logger.debug(f"WARNING: bad line in file {filename} '{line}', skipped")
                     continue
 
                 key = key.strip()
                 value = value.strip()
-                if (
-                    self.spec.features_to_skip is not None
-                    and key in self.spec.features_to_skip
-                ):
+                if self.spec.features_to_skip is not None and key in self.spec.features_to_skip:
                     logger.debug(f"Skipping feature {key}")
                     continue
 
@@ -318,9 +317,7 @@ class TrainingBlueprint(ABC):
 
     # Note: children should implement.  Consider using @parallelize.
     @abstractmethod
-    def train_model(
-        self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray
-    ) -> Any:
+    def train_model(self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray) -> Any:
         pass
 
     def evaluate_model(
@@ -365,18 +362,17 @@ Testing set score: {test_score:.2f}%"""
                 self.spec.persist_percentage_threshold is not None
                 and test_score > self.spec.persist_percentage_threshold
             ) or (
-                not self.spec.quiet
-                and input_utils.yn_response("Write the model? [y,n]: ") == "y"
+                not self.spec.quiet and input_utils.yn_response("Write the model? [y,n]: ") == "y"
             ):
                 scaler_filename = f"{self.spec.basename}_scaler.sav"
-                with open(scaler_filename, "wb") as f:
-                    pickle.dump(scaler, f)
+                with open(scaler_filename, "wb") as fb:
+                    pickle.dump(scaler, fb)
                 msg = f"Wrote {scaler_filename}"
                 print(msg)
                 logger.info(msg)
                 model_filename = f"{self.spec.basename}_model.sav"
-                with open(model_filename, "wb") as f:
-                    pickle.dump(model, f)
+                with open(model_filename, "wb") as fb:
+                    pickle.dump(model, fb)
                 msg = f"Wrote {model_filename}"
                 print(msg)
                 logger.info(msg)