3 # © Copyright 2021-2022, Scott Gasch
5 """This is a blueprint for training sklearn ML models."""
7 from __future__ import annotations
16 from abc import ABC, abstractmethod
17 from dataclasses import dataclass
18 from types import SimpleNamespace
19 from typing import Any, List, Optional, Set, Tuple
22 from sklearn.model_selection import train_test_split # type:ignore
23 from sklearn.preprocessing import MinMaxScaler # type: ignore
28 import parallelize as par
29 from ansi import bold, reset
30 from decorator_utils import timed
32 logger = logging.getLogger(__name__)
# Module-level argparse wiring: registers this trainer's command-line flags
# with the project's shared `config` module.
# NOTE(review): several add_argument(...) calls are truncated in this view;
# only the fragments below are visible — confirm against the full file.
34 parser = config.add_commandline_args(
35 f"ML Model Trainer ({__file__})",
36 "Arguments related to training an ML model",
# Presumably belongs to an --ml_trainer_quiet flag (train() reads
# config.config["ml_trainer_quiet"]) — confirm.
41 help="Don't prompt the user for anything.",
# When set, malformed features files are deleted rather than just warned about
# (consumed by read_files_from_list via spec.delete_bad_inputs).
44 "--ml_trainer_delete",
46 help="Delete invalid/incomplete features files in addition to warning.",
# dry-run and the persist threshold are mutually exclusive choices.
48 group = parser.add_mutually_exclusive_group()
50 "--ml_trainer_dry_run",
52 help="Do not write a new model, just report efficacy.",
55 "--ml_trainer_persist_threshold",
# valid_percentage validates that the value parses as a percentage.
56 type=argparse_utils.valid_percentage,
58 help="Persist the model if the test set score is >= this threshold.",
62 class InputSpec(SimpleNamespace):
63 """A collection of info needed to train the model provided by the
# NOTE(review): the docstring continuation and several attributes (e.g. the
# file glob, label, feature_count and basename referenced elsewhere in this
# file) are truncated from this view.
# features_to_skip: feature keys dropped while parsing input files.
68 features_to_skip: Set[str]
# key_value_delimiter: separator between key and value in features files.
69 key_value_delimiter: str
# training_parameters: one parameter set per candidate model to train.
70 training_parameters: List
# dry_run: when True, report efficacy but persist nothing to disk.
73 dry_run: Optional[bool]
# persist_percentage_threshold: minimum test-set score (in %) at which the
# trained model is persisted without prompting.
75 persist_percentage_threshold: Optional[float]
# delete_bad_inputs: when True, delete malformed features files instead of
# merely warning about them.
76 delete_bad_inputs: Optional[bool]
79 def populate_from_config() -> InputSpec:
# Build a (partial) InputSpec from the parsed command-line configuration.
# NOTE(review): the enclosing constructor call / return statement is not
# visible in this view; only the keyword arguments below are shown.
81 dry_run=config.config["ml_trainer_dry_run"],
82 quiet=config.config["ml_trainer_quiet"],
83 persist_percentage_threshold=config.config["ml_trainer_persist_threshold"],
84 delete_bad_inputs=config.config["ml_trainer_delete"],
# NOTE(review): the class/dataclass header for OutputSpec is in a truncated
# region above this docstring — confirm against the full file.
90 """Info about the results of training returned to the caller."""
# Filenames remain None when nothing was persisted (dry run, low score, or
# the user declined the write prompt in maybe_persist_scaler_and_model).
92 model_filename: Optional[str] = None
93 model_info_filename: Optional[str] = None
94 scaler_filename: Optional[str] = None
# Scores are percentages (evaluate_model multiplies model.score() by 100.0).
95 training_score: np.float64 = np.float64(0.0)
96 test_score: np.float64 = np.float64(0.0)
99 class TrainingBlueprint(ABC):
100 """The blueprint for doing the actual training."""
# NOTE(review): the __init__ signature and some attribute initializations
# are truncated from this view; only the assignments below are visible.
105 self.X_test_scaled = None
106 self.X_train_scaled = None
# Progress counters updated by read_files_from_list and displayed by
# make_progress_graph while input files are being read.
107 self.file_done_count = 0
108 self.total_file_count = 0
# Orchestrates the whole pipeline: read features files, split test/train,
# scale, train one model per parameter set (possibly as futures), score
# each, keep the best, and (maybe) persist it. Returns an OutputSpec.
# NOTE(review): many interior lines of this method are truncated in this
# view (e.g. the assignment of best_params, the models list creation, and
# the OutputSpec construction header) — comments below cover only what is
# visible.
111 def train(self, spec: InputSpec) -> OutputSpec:
117 X_, y_ = self.read_input_files()
118 num_examples = len(y_)
120 # Every example's features
123 # Every example's label
126 print("Doing random test/train split...")
127 X_train, X_test, self.y_train, self.y_test = TrainingBlueprint.test_train_split(
132 print("Scaling training data...")
133 scaler, self.X_train_scaled, self.X_test_scaled = TrainingBlueprint.scale_data(
138 print("Training model(s)...")
# Maps a model's id back to the (stringified) params that produced it so
# results can be reported per parameter set.
140 modelid_to_params = {}
141 for params in self.spec.training_parameters:
142 model = self.train_model(params, self.X_train_scaled, self.y_train)
144 modelid_to_params[model.get_id()] = str(params)
147 best_score: Optional[np.float64] = None
148 best_test_score: Optional[np.float64] = None
149 best_training_score: Optional[np.float64] = None
# Consume trained models as they complete; train_model may return
# SmartFutures (see the @parallelize note near its declaration), which are
# resolved to concrete models here.
151 for model in smart_future.wait_any(models):
152 params = modelid_to_params[model.get_id()]
153 if isinstance(model, smart_future.SmartFuture):
154 model = model._resolve()
155 if model is not None:
156 training_score, test_score = TrainingBlueprint.evaluate_model(
# Composite score weights the test-set score 20x more heavily than the
# training-set score (guards against overfit models winning).
163 score = (training_score + test_score * 20) / 21
164 if not self.spec.quiet:
166 f"{bold()}{params}{reset()}: "
167 f"Training set score={training_score:.2f}%, "
168 f"test set score={test_score:.2f}%",
171 if best_score is None or score > best_score:
173 best_test_score = test_score
174 best_training_score = training_score
177 if not self.spec.quiet:
178 print(f"New best score {best_score:.2f}% with params {params}")
180 if not self.spec.quiet:
181 executors.DefaultExecutors().shutdown()
182 msg = f"Done training; best test set score was: {best_test_score:.1f}%"
# These asserts guarantee at least one model trained successfully before
# attempting persistence. NOTE(review): best_params' assignment is in a
# truncated region — confirm it is set alongside best_test_score above.
186 assert best_training_score is not None
187 assert best_test_score is not None
188 assert best_params is not None
193 ) = self.maybe_persist_scaler_and_model(
202 model_filename=model_filename,
203 model_info_filename=model_info_filename,
204 scaler_filename=scaler_filename,
205 training_score=best_training_score,
206 test_score=best_test_score,
# Parses a shard of features files (run on a thread via @par.parallelize;
# callers receive a future resolved in read_input_files). Each file holds
# key<delimiter>value lines; returns parallel lists of feature vectors and
# labels. NOTE(review): interior lines (the accumulators, label handling,
# and the return statement) are truncated in this view.
209 @par.parallelize(method=par.Method.THREAD)
210 def read_files_from_list(self, files: List[str]) -> Tuple[List, List]:
217 for filename in files:
219 with open(filename, "r") as f:
220 lines = f.readlines()
222 # This example's features
226 # We expect lines in features files to be of the form:
231 (key, value) = line.split(self.spec.key_value_delimiter)
# NOTE(review): the message says WARNING but is logged at DEBUG level —
# consider logger.warning here.
233 logger.debug("WARNING: bad line in file %s '%s', skipped", filename, line)
237 value = value.strip()
238 if self.spec.features_to_skip is not None and key in self.spec.features_to_skip:
239 logger.debug("Skipping feature %s", key)
242 value = TrainingBlueprint.normalize_feature(value)
244 if key == self.spec.label:
250 # Make sure we saw a label and the requisite number of features.
251 if len(x) == self.spec.feature_count and wrote_label:
# Progress counter consumed by make_progress_graph.
253 self.file_done_count += 1
# Malformed file: either warn-and-skip or warn-and-delete per config.
258 if self.spec.delete_bad_inputs:
259 msg = f"WARNING: (unknown): missing features or label; expected {self.spec.feature_count} but saw {len(x)}. DELETING."
264 msg = f"WARNING: (unknown): missing features or label; expected {self.spec.feature_count} but saw {len(x)}. Skipping."
# Draws a textual progress bar of files-read vs. files-total; a no-op when
# quiet mode is on. Used as the wait_any() callback in read_input_files.
# NOTE(review): a line directly above this def is truncated — there may be
# a decorator here in the full file; confirm.
269 def make_progress_graph(self) -> None:
270 if not self.spec.quiet:
# Imported lazily so quiet runs never pay for text_utils.
271 from text_utils import progress_graph
273 progress_graph(self.file_done_count, self.total_file_count)
# Globs all input files, shards them into batches of 500, fans each batch
# out to read_files_from_list (which runs on a thread and returns a
# future), then gathers results as they complete, updating the progress
# graph. NOTE(review): the results-list initialization, per-result
# accumulation, and return statement are truncated in this view.
276 def read_input_files(self):
287 all_files = glob.glob(self.spec.file_glob)
288 self.total_file_count = len(all_files)
# 500 files per shard / per worker thread.
289 for files in list_utils.shard(all_files, 500):
290 file_list = list(files)
291 results.append(self.read_files_from_list(file_list))
# wait_any yields futures as they complete; _resolve unwraps each into
# its concrete (features, labels) payload.
293 for result in smart_future.wait_any(results, callback=self.make_progress_graph):
294 result = result._resolve()
# Blank out the progress-graph line before returning.
299 if not self.spec.quiet:
300 print(" " * 80 + "\n")
# Coerces a raw string feature value into a numeric form: "False"/"None"
# and "True" map to fixed values (assignments truncated from this view),
# and decimal-looking strings are scaled to integer percentages.
# NOTE(review): a decorator (likely @staticmethod, given how callers invoke
# TrainingBlueprint.normalize_feature), the other branch bodies, and the
# return statement are truncated — confirm against the full file.
304 def normalize_feature(value: str) -> Any:
305 if value in ("False", "None"):
307 elif value == "True":
# The isinstance check is redundant given the str annotation, but harmless.
309 elif isinstance(value, str) and "." in value:
# e.g. "0.25" -> 25; keeps features integral.
310 ret = round(float(value) * 100.0)
# Thin wrapper around sklearn's train_test_split with a randomized seed.
# NOTE(review): a decorator (likely @staticmethod — train() calls this as
# TrainingBlueprint.test_train_split), the X/y argument lines, and the
# closing paren are truncated in this view.
316 def test_train_split(X, y) -> List:
317 logger.debug("Performing test/train split")
318 return train_test_split(
# Fresh random seed per run, so splits differ between invocations.
321 random_state=random.randrange(0, 1000),
# Fits a MinMaxScaler and returns (scaler, scaled_train, scaled_test).
# The fitted scaler is returned so it can be persisted alongside the model.
# NOTE(review): the scaler.fit(X_train) call (and likely a @staticmethod
# decorator) sit in truncated lines in this view — transform() on line 329
# implies fit happened in between; confirm.
325 def scale_data(X_train: np.ndarray, X_test: np.ndarray) -> Tuple[Any, np.ndarray, np.ndarray]:
326 logger.debug("Scaling data")
327 scaler = MinMaxScaler()
329 return (scaler, scaler.transform(X_train), scaler.transform(X_test))
331 # Note: children should implement. Consider using @parallelize.
# Abstract hook: train one model from one parameter set on the scaled
# training data. May return either a trained model or a SmartFuture of one
# (train() resolves futures). NOTE(review): the @abstractmethod decorator
# and body are truncated from this view.
333 def train_model(self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray) -> Any:
# Scores a trained model on both the training and the test set, returning
# (training_score, test_score) as percentages (model.score() * 100).
# NOTE(review): the def line, decorator, and remaining parameters (model,
# y_train, y_test) are truncated from this view.
339 X_train_scaled: np.ndarray,
341 X_test_scaled: np.ndarray,
343 ) -> Tuple[np.float64, np.float64]:
344 logger.debug("Evaluating the model")
345 training_score = model.score(X_train_scaled, y_train) * 100.0
346 test_score = model.score(X_test_scaled, y_test) * 100.0
# Lazy %-style args keep formatting off the hot path when DEBUG is off.
348 "Model evaluation results: test_score=%.5f, train_score=%.5f",
352 return (training_score, test_score)
# Persists the scaler, model, and a human-readable info file — but only
# when not a dry run AND (the test score beats the configured threshold OR
# the user, when prompted, approves). Returns the three filenames written,
# or (None, None, None) when nothing was persisted.
# NOTE(review): several parameters of the signature, the condition
# connecting lines 376-379, and the info-file write are truncated in this
# view; the unpack order expected by train() could not be verified here.
354 def maybe_persist_scaler_and_model(
356 training_score: np.float64,
357 test_score: np.float64,
362 ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
363 if not self.spec.dry_run:
# Imported lazily; only needed on the persistence path.
364 import datetime_utils
368 now: datetime.datetime = datetime_utils.now_pacific()
# Human-readable provenance record written next to the pickled model.
369 info = f"""Timestamp: {datetime_utils.datetime_to_string(now)}
370 Model params: {params}
371 Training examples: {num_examples}
372 Training set score: {training_score:.2f}%
373 Testing set score: {test_score:.2f}%"""
# Persist automatically above the threshold, otherwise fall back to an
# interactive yes/no prompt (skipped in quiet mode).
376 self.spec.persist_percentage_threshold is not None
377 and test_score > self.spec.persist_percentage_threshold
379 not self.spec.quiet and input_utils.yn_response("Write the model? [y,n]: ") == "y"
# Three artifacts share the configured basename: scaler, model, info.
381 scaler_filename = f"{self.spec.basename}_scaler.sav"
382 with open(scaler_filename, "wb") as fb:
383 pickle.dump(scaler, fb)
384 msg = f"Wrote {scaler_filename}"
387 model_filename = f"{self.spec.basename}_model.sav"
388 with open(model_filename, "wb") as fb:
389 pickle.dump(model, fb)
390 msg = f"Wrote {model_filename}"
393 model_info_filename = f"{self.spec.basename}_model_info.txt"
394 with open(model_info_filename, "w") as f:
396 msg = f"Wrote {model_info_filename}:"
399 print(string_utils.indent(info, 2))
401 return (scaler_filename, model_filename, model_info_filename)
# Reached on dry runs or when persistence was declined.
402 return (None, None, None)