X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=ml%2Fmodel_trainer.py;h=12ccb3c6c0508e61081d6791b12f9a8f5a5571f2;hb=713a609bd19d491de03debf8a4a6ddf2540b13dc;hp=213a1814cff5e98507e30c19e17669ab123886ce;hpb=36fe954a689c26e7082c61c1c8dbbf76dd7cf6c8;p=python_utils.git diff --git a/ml/model_trainer.py b/ml/model_trainer.py index 213a181..12ccb3c 100644 --- a/ml/model_trainer.py +++ b/ml/model_trainer.py @@ -2,7 +2,6 @@ from __future__ import annotations -from abc import ABC, abstractmethod import datetime import glob import logging @@ -10,25 +9,27 @@ import os import pickle import random import sys +import warnings +from abc import ABC, abstractmethod from types import SimpleNamespace from typing import Any, List, NamedTuple, Optional, Set, Tuple -import warnings import numpy as np from sklearn.model_selection import train_test_split # type:ignore from sklearn.preprocessing import MinMaxScaler # type: ignore -from ansi import bold, reset import argparse_utils import config -from decorator_utils import timed import executors import parallelize as par +from ansi import bold, reset +from decorator_utils import timed logger = logging.getLogger(__file__) parser = config.add_commandline_args( - f"ML Model Trainer ({__file__})", "Arguments related to training an ML model" + f"ML Model Trainer ({__file__})", + "Arguments related to training an ML model", ) parser.add_argument( "--ml_trainer_quiet", @@ -81,8 +82,8 @@ class OutputSpec(NamedTuple): model_filename: Optional[str] model_info_filename: Optional[str] scaler_filename: Optional[str] - training_score: float - test_score: float + training_score: np.float64 + test_score: np.float64 class TrainingBlueprint(ABC): @@ -131,9 +132,9 @@ class TrainingBlueprint(ABC): modelid_to_params[model.get_id()] = str(params) best_model = None - best_score = None - best_test_score = None - best_training_score = None + best_score: Optional[np.float64] = None + best_test_score: Optional[np.float64] = None + best_training_score: Optional[np.float64] = None best_params = None for model in smart_future.wait_any(models): params = modelid_to_params[model.get_id()] @@ -170,6 +171,9 @@ class TrainingBlueprint(ABC): print(msg) logger.info(msg) + assert best_training_score is not None + assert best_test_score is not None + assert best_params is not None ( scaler_filename, model_filename, @@ -214,17 +218,12 @@ class TrainingBlueprint(ABC): try: (key, value) = line.split(self.spec.key_value_delimiter) except Exception: - logger.debug( - f"WARNING: bad line in file {filename} '{line}', skipped" - ) + logger.debug(f"WARNING: bad line in file {filename} '{line}', skipped") continue key = key.strip() value = value.strip() - if ( - self.spec.features_to_skip is not None - and key in self.spec.features_to_skip - ): + if self.spec.features_to_skip is not None and key in self.spec.features_to_skip: logger.debug(f"Skipping feature {key}") continue @@ -318,9 +317,7 @@ class TrainingBlueprint(ABC): # Note: children should implement. Consider using @parallelize. @abstractmethod - def train_model( - self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray - ) -> Any: + def train_model(self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray) -> Any: pass def evaluate_model( @@ -365,18 +362,17 @@ Testing set score: {test_score:.2f}%""" self.spec.persist_percentage_threshold is not None and test_score > self.spec.persist_percentage_threshold ) or ( - not self.spec.quiet - and input_utils.yn_response("Write the model? [y,n]: ") == "y" + not self.spec.quiet and input_utils.yn_response("Write the model? [y,n]: ") == "y" ): scaler_filename = f"{self.spec.basename}_scaler.sav" - with open(scaler_filename, "wb") as f: - pickle.dump(scaler, f) + with open(scaler_filename, "wb") as fb: + pickle.dump(scaler, fb) msg = f"Wrote {scaler_filename}" print(msg) logger.info(msg) model_filename = f"{self.spec.basename}_model.sav" - with open(model_filename, "wb") as f: - pickle.dump(model, f) + with open(model_filename, "wb") as fb: + pickle.dump(model, fb) msg = f"Wrote {model_filename}" print(msg) logger.info(msg)