Change settings in flake8 and black.
[python_utils.git] / ml / model_trainer.py
index 213a1814cff5e98507e30c19e17669ab123886ce..12ccb3c6c0508e61081d6791b12f9a8f5a5571f2 100644 (file)
@@ -2,7 +2,6 @@
 
 from __future__ import annotations
 
 
 from __future__ import annotations
 
-from abc import ABC, abstractmethod
 import datetime
 import glob
 import logging
 import datetime
 import glob
 import logging
@@ -10,25 +9,27 @@ import os
 import pickle
 import random
 import sys
 import pickle
 import random
 import sys
+import warnings
+from abc import ABC, abstractmethod
 from types import SimpleNamespace
 from typing import Any, List, NamedTuple, Optional, Set, Tuple
 from types import SimpleNamespace
 from typing import Any, List, NamedTuple, Optional, Set, Tuple
-import warnings
 
 import numpy as np
 from sklearn.model_selection import train_test_split  # type:ignore
 from sklearn.preprocessing import MinMaxScaler  # type: ignore
 
 
 import numpy as np
 from sklearn.model_selection import train_test_split  # type:ignore
 from sklearn.preprocessing import MinMaxScaler  # type: ignore
 
-from ansi import bold, reset
 import argparse_utils
 import config
 import argparse_utils
 import config
-from decorator_utils import timed
 import executors
 import parallelize as par
 import executors
 import parallelize as par
+from ansi import bold, reset
+from decorator_utils import timed
 
 logger = logging.getLogger(__file__)
 
 parser = config.add_commandline_args(
 
 logger = logging.getLogger(__file__)
 
 parser = config.add_commandline_args(
-    f"ML Model Trainer ({__file__})", "Arguments related to training an ML model"
+    f"ML Model Trainer ({__file__})",
+    "Arguments related to training an ML model",
 )
 parser.add_argument(
     "--ml_trainer_quiet",
 )
 parser.add_argument(
     "--ml_trainer_quiet",
@@ -81,8 +82,8 @@ class OutputSpec(NamedTuple):
     model_filename: Optional[str]
     model_info_filename: Optional[str]
     scaler_filename: Optional[str]
     model_filename: Optional[str]
     model_info_filename: Optional[str]
     scaler_filename: Optional[str]
-    training_score: float
-    test_score: float
+    training_score: np.float64
+    test_score: np.float64
 
 
 class TrainingBlueprint(ABC):
 
 
 class TrainingBlueprint(ABC):
@@ -131,9 +132,9 @@ class TrainingBlueprint(ABC):
             modelid_to_params[model.get_id()] = str(params)
 
         best_model = None
             modelid_to_params[model.get_id()] = str(params)
 
         best_model = None
-        best_score = None
-        best_test_score = None
-        best_training_score = None
+        best_score: Optional[np.float64] = None
+        best_test_score: Optional[np.float64] = None
+        best_training_score: Optional[np.float64] = None
         best_params = None
         for model in smart_future.wait_any(models):
             params = modelid_to_params[model.get_id()]
         best_params = None
         for model in smart_future.wait_any(models):
             params = modelid_to_params[model.get_id()]
@@ -170,6 +171,9 @@ class TrainingBlueprint(ABC):
             print(msg)
             logger.info(msg)
 
             print(msg)
             logger.info(msg)
 
+        assert best_training_score is not None
+        assert best_test_score is not None
+        assert best_params is not None
         (
             scaler_filename,
             model_filename,
         (
             scaler_filename,
             model_filename,
@@ -214,17 +218,12 @@ class TrainingBlueprint(ABC):
                 try:
                     (key, value) = line.split(self.spec.key_value_delimiter)
                 except Exception:
                 try:
                     (key, value) = line.split(self.spec.key_value_delimiter)
                 except Exception:
-                    logger.debug(
-                        f"WARNING: bad line in file {filename} '{line}', skipped"
-                    )
+                    logger.debug(f"WARNING: bad line in file {filename} '{line}', skipped")
                     continue
 
                 key = key.strip()
                 value = value.strip()
                     continue
 
                 key = key.strip()
                 value = value.strip()
-                if (
-                    self.spec.features_to_skip is not None
-                    and key in self.spec.features_to_skip
-                ):
+                if self.spec.features_to_skip is not None and key in self.spec.features_to_skip:
                     logger.debug(f"Skipping feature {key}")
                     continue
 
                     logger.debug(f"Skipping feature {key}")
                     continue
 
@@ -318,9 +317,7 @@ class TrainingBlueprint(ABC):
 
     # Note: children should implement.  Consider using @parallelize.
     @abstractmethod
 
     # Note: children should implement.  Consider using @parallelize.
     @abstractmethod
-    def train_model(
-        self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray
-    ) -> Any:
+    def train_model(self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray) -> Any:
         pass
 
     def evaluate_model(
         pass
 
     def evaluate_model(
@@ -365,18 +362,17 @@ Testing set score: {test_score:.2f}%"""
                 self.spec.persist_percentage_threshold is not None
                 and test_score > self.spec.persist_percentage_threshold
             ) or (
                 self.spec.persist_percentage_threshold is not None
                 and test_score > self.spec.persist_percentage_threshold
             ) or (
-                not self.spec.quiet
-                and input_utils.yn_response("Write the model? [y,n]: ") == "y"
+                not self.spec.quiet and input_utils.yn_response("Write the model? [y,n]: ") == "y"
             ):
                 scaler_filename = f"{self.spec.basename}_scaler.sav"
             ):
                 scaler_filename = f"{self.spec.basename}_scaler.sav"
-                with open(scaler_filename, "wb") as f:
-                    pickle.dump(scaler, f)
+                with open(scaler_filename, "wb") as fb:
+                    pickle.dump(scaler, fb)
                 msg = f"Wrote {scaler_filename}"
                 print(msg)
                 logger.info(msg)
                 model_filename = f"{self.spec.basename}_model.sav"
                 msg = f"Wrote {scaler_filename}"
                 print(msg)
                 logger.info(msg)
                 model_filename = f"{self.spec.basename}_model.sav"
-                with open(model_filename, "wb") as f:
-                    pickle.dump(model, f)
+                with open(model_filename, "wb") as fb:
+                    pickle.dump(model, fb)
                 msg = f"Wrote {model_filename}"
                 print(msg)
                 logger.info(msg)
                 msg = f"Wrote {model_filename}"
                 print(msg)
                 logger.info(msg)