Let's call the base class' c'tor first, eh?
[python_utils.git] / ml / model_trainer.py
index 213a1814cff5e98507e30c19e17669ab123886ce..a37885ce3e92f20d11bb61f4268407689829bca4 100644 (file)
@@ -2,7 +2,6 @@
 
 from __future__ import annotations
 
-from abc import ABC, abstractmethod
 import datetime
 import glob
 import logging
@@ -10,20 +9,21 @@ import os
 import pickle
 import random
 import sys
+import warnings
+from abc import ABC, abstractmethod
 from types import SimpleNamespace
 from typing import Any, List, NamedTuple, Optional, Set, Tuple
-import warnings
 
 import numpy as np
 from sklearn.model_selection import train_test_split  # type:ignore
 from sklearn.preprocessing import MinMaxScaler  # type: ignore
 
-from ansi import bold, reset
 import argparse_utils
 import config
-from decorator_utils import timed
 import executors
 import parallelize as par
+from ansi import bold, reset
+from decorator_utils import timed
 
 logger = logging.getLogger(__file__)
 
@@ -81,8 +81,8 @@ class OutputSpec(NamedTuple):
     model_filename: Optional[str]
     model_info_filename: Optional[str]
     scaler_filename: Optional[str]
-    training_score: float
-    test_score: float
+    training_score: np.float64
+    test_score: np.float64
 
 
 class TrainingBlueprint(ABC):
@@ -131,9 +131,9 @@ class TrainingBlueprint(ABC):
             modelid_to_params[model.get_id()] = str(params)
 
         best_model = None
-        best_score = None
-        best_test_score = None
-        best_training_score = None
+        best_score: Optional[np.float64] = None
+        best_test_score: Optional[np.float64] = None
+        best_training_score: Optional[np.float64] = None
         best_params = None
         for model in smart_future.wait_any(models):
             params = modelid_to_params[model.get_id()]
@@ -170,6 +170,9 @@ class TrainingBlueprint(ABC):
             print(msg)
             logger.info(msg)
 
+        assert best_training_score is not None
+        assert best_test_score is not None
+        assert best_params is not None
         (
             scaler_filename,
             model_filename,
@@ -369,14 +372,14 @@ Testing set score: {test_score:.2f}%"""
                 and input_utils.yn_response("Write the model? [y,n]: ") == "y"
             ):
                 scaler_filename = f"{self.spec.basename}_scaler.sav"
-                with open(scaler_filename, "wb") as f:
-                    pickle.dump(scaler, f)
+                with open(scaler_filename, "wb") as fb:
+                    pickle.dump(scaler, fb)
                 msg = f"Wrote {scaler_filename}"
                 print(msg)
                 logger.info(msg)
                 model_filename = f"{self.spec.basename}_model.sav"
-                with open(model_filename, "wb") as f:
-                    pickle.dump(model, f)
+                with open(model_filename, "wb") as fb:
+                    pickle.dump(model, fb)
                 msg = f"Wrote {model_filename}"
                 print(msg)
                 logger.info(msg)