3 # © Copyright 2021-2022, Scott Gasch
5 """This is a blueprint for training sklearn ML models."""
7 from __future__ import annotations
17 from abc import ABC, abstractmethod
18 from dataclasses import dataclass
19 from types import SimpleNamespace
20 from typing import Any, List, Optional, Set, Tuple
23 from sklearn.model_selection import train_test_split # type:ignore
24 from sklearn.preprocessing import MinMaxScaler # type: ignore
29 import parallelize as par
30 from ansi import bold, reset
31 from decorator_utils import timed
# Module-level logger used throughout this file.
33 logger = logging.getLogger(__name__)
# Register this module's command-line flags with the project's config
# module. Their values are read back from config.config in
# populate_from_config().
35 parser = config.add_commandline_args(
36 f"ML Model Trainer ({__file__})",
37 "Arguments related to training an ML model",
# NOTE(review): the flag name for this help text is on a line not visible
# here; populate_from_config() reads config.config["ml_trainer_quiet"], so
# presumably it is --ml_trainer_quiet -- verify.
42 help="Don't prompt the user for anything.",
45 "--ml_trainer_delete",
47 help="Delete invalid/incomplete features files in addition to warning.",
# NOTE(review): --ml_trainer_dry_run and --ml_trainer_persist_threshold are
# presumably both added to this mutually exclusive group (their
# add_argument lines are not visible here), so a dry run cannot be combined
# with an auto-persist threshold -- verify.
49 group = parser.add_mutually_exclusive_group()
51 "--ml_trainer_dry_run",
53 help="Do not write a new model, just report efficacy.",
# The threshold is validated as a percentage; it is compared against the
# test-set score, which evaluate_model() also reports as a percentage.
56 "--ml_trainer_persist_threshold",
57 type=argparse_utils.valid_percentage,
59 help="Persist the model if the test set score is >= this threshold.",
# Inputs consumed by TrainingBlueprint.train() (accessed there as self.spec).
# features_to_skip: feature keys ignored when reading features files.
# NOTE(review): other attributes read via self.spec elsewhere in this file
# (file_glob, feature_count, label, basename, quiet) are declared on lines
# not visible in this extract.
63 class InputSpec(SimpleNamespace):
64 """A collection of info needed to train the model provided by the
69 features_to_skip: Set[str]
# Separator between the key and the value on each features-file line.
70 key_value_delimiter: str
# One entry per model to train; each entry is passed to train_model().
71 training_parameters: List
# If true, report efficacy only and never write model files.
74 dry_run: Optional[bool]
# Test-set score percentage compared against in
# maybe_persist_scaler_and_model() to decide whether to persist without
# prompting.
76 persist_percentage_threshold: Optional[float]
# If true, features files missing features or the label are deleted.
77 delete_bad_inputs: Optional[bool]
# Build an InputSpec from this module's command-line flags. Only the
# flag-driven fields are visible here (dry_run, quiet, persist threshold,
# delete_bad_inputs).
# NOTE(review): the remaining InputSpec fields (file glob, label, training
# parameters, ...) are presumably filled in by the caller -- verify.
80 def populate_from_config() -> InputSpec:
82 dry_run=config.config["ml_trainer_dry_run"],
83 quiet=config.config["ml_trainer_quiet"],
84 persist_percentage_threshold=config.config["ml_trainer_persist_threshold"],
85 delete_bad_inputs=config.config["ml_trainer_delete"],
91 """Info about the results of training returned to the caller."""
# Paths written by maybe_persist_scaler_and_model(); each stays None when
# nothing was persisted (dry run, or the model was not accepted).
93 model_filename: Optional[str] = None
94 model_info_filename: Optional[str] = None
95 scaler_filename: Optional[str] = None
# Scores of the winning model, as percentages (model.score() * 100).
96 training_score: np.float64 = np.float64(0.0)
97 test_score: np.float64 = np.float64(0.0)
100 class TrainingBlueprint(ABC):
101 """The blueprint for doing the actual training."""
# NOTE(review): the __init__ signature is on lines not visible in this extract.
# Scaled feature matrices; assigned by train() after scale_data().
106 self.X_test_scaled = None
107 self.X_train_scaled = None
# Progress counters for reading features files: file_done_count is advanced
# in read_files_from_list(), total_file_count is set in read_input_files(),
# and both are drawn by make_progress_graph().
108 self.file_done_count = 0
109 self.total_file_count = 0
# Read all features files, do a random train/test split, and min-max scale
# the data. Then train one model per entry in spec.training_parameters,
# pick the best by a blended score, and optionally persist it.
#
# Returns an OutputSpec holding the persisted filenames (or None) and the
# winning model's training/test scores as percentages.
112 def train(self, spec: InputSpec) -> OutputSpec:
# Feature rows (X_) and labels (y_) gathered from every features file.
118 X_, y_ = self.read_input_files()
119 num_examples = len(y_)
121 # Every example's features
124 # Every example's label
127 print("Doing random test/train split...")
128 X_train, X_test, self.y_train, self.y_test = TrainingBlueprint.test_train_split(
133 print("Scaling training data...")
134 scaler, self.X_train_scaled, self.X_test_scaled = TrainingBlueprint.scale_data(
139 print("Training model(s)...")
# Parameters are keyed by model id so each completed result can be matched
# back to the parameters that produced it. train_model() may hand back a
# SmartFuture (see the isinstance check below).
# NOTE(review): the list `models` is presumably appended to on a line not
# visible here -- verify.
141 modelid_to_params = {}
142 for params in self.spec.training_parameters:
143 model = self.train_model(params, self.X_train_scaled, self.y_train)
145 modelid_to_params[model.get_id()] = str(params)
149 best_score: Optional[np.float64] = None
150 best_test_score: Optional[np.float64] = None
151 best_training_score: Optional[np.float64] = None
# Consume models as they become available, resolving futures as needed.
153 for model in smart_future.wait_any(models):
154 params = modelid_to_params[model.get_id()]
155 if isinstance(model, smart_future.SmartFuture):
156 model = model._resolve()
157 if model is not None:
158 training_score, test_score = TrainingBlueprint.evaluate_model(
# Blended score: the test-set score is weighted 20x the training-set score,
# so held-out performance dominates model selection.
165 score = (training_score + test_score * 20) / 21
166 all_models[params] = (score, training_score, test_score)
167 if not self.spec.quiet:
169 f"{bold()}{params}{reset()}: "
170 f"Training set score={training_score:.2f}%, "
171 f"test set score={test_score:.2f}%",
# Track the model with the highest blended score seen so far.
174 if best_score is None or score > best_score:
176 best_test_score = test_score
177 best_training_score = training_score
180 if not self.spec.quiet:
181 print(f"New best score {best_score:.2f}% with params {params}")
# All models have been consumed; shut down the default executors
# (presumably the ones train_model()'s futures ran on -- verify).
183 executors.DefaultExecutors().shutdown()
# NOTE(review): these asserts fail if no model resolved to a non-None value,
# and they are stripped under `python -O`; an explicit exception may be
# preferable.
184 assert best_training_score is not None
185 assert best_test_score is not None
186 assert best_params is not None
188 if not self.spec.quiet:
190 print('Done training...')
# Summarize every model as "score (test, train)" and mark the winner.
191 for params in all_models:
192 msg = f'{bold()}{params}{reset()}: score={all_models[params][0]:.2f}% '
193 msg += f'({all_models[params][2]:.2f}% test, '
194 msg += f'{all_models[params][1]:.2f}% train)'
195 if params == best_params:
196 msg += f'{bold()} <-- winner{reset()}'
203 ) = self.maybe_persist_scaler_and_model(
# Filenames are None when nothing was persisted; scores are the winner's.
212 model_filename=model_filename,
213 model_info_filename=model_info_filename,
214 scaler_filename=scaler_filename,
215 training_score=best_training_score,
216 test_score=best_test_score,
# Parse one shard of features files into (features, labels) lists.
# Decorated with @parallelize(THREAD). read_input_files() therefore gets back
# an object it resolves with _resolve() (presumably a SmartFuture -- verify).
219 @par.parallelize(method=par.Method.THREAD)
220 def read_files_from_list(self, files: List[str]) -> Tuple[List, List]:
227 for filename in files:
229 with open(filename, "r") as f:
230 lines = f.readlines()
232 # This example's features
236 # We expect lines in features files to be of the form:
241 (key, value) = line.split(self.spec.key_value_delimiter)
# Presumably reached when the split above does not yield exactly (key,
# value); the except clause is not visible here.
# NOTE(review): this "WARNING" is logged at DEBUG level, so it is hidden by
# default -- confirm that is intended.
243 logger.debug("WARNING: bad line in file %s '%s', skipped", filename, line)
247 value = value.strip()
# Keys listed in spec.features_to_skip are ignored.
248 if self.spec.features_to_skip is not None and key in self.spec.features_to_skip:
249 logger.debug("Skipping feature %s", key)
252 value = TrainingBlueprint.normalize_feature(value)
# The key matching spec.label supplies this example's label.
254 if key == self.spec.label:
260 # Make sure we saw a label and the requisite number of features.
261 if len(x) == self.spec.feature_count and wrote_label:
263 self.file_done_count += 1
# Incomplete files are reported. With spec.delete_bad_inputs the message
# says DELETING; the removal itself is on lines not visible here.
268 if self.spec.delete_bad_inputs:
269 msg = f"WARNING: {filename}: missing features or label; expected {self.spec.feature_count} but saw {len(x)}. DELETING."
274 msg = f"WARNING: {filename}: missing features or label; expected {self.spec.feature_count} but saw {len(x)}. Skipping."
# Progress callback passed to smart_future.wait_any() by read_input_files().
# Unless running quietly, it draws a bar of files read so far against the
# total file count.
279 def make_progress_graph(self) -> None:
280 if not self.spec.quiet:
281 from text_utils import bar_graph
283 bar_graph(self.file_done_count, self.total_file_count)
# Read every file matching spec.file_glob and return its (X, y) data.
# Files are split with list_utils.shard(..., 500) (presumably shards of up
# to 500 files), and each shard goes to read_files_from_list(). Results are
# gathered as they complete, with make_progress_graph() as progress callback.
286 def read_input_files(self):
297 all_files = glob.glob(self.spec.file_glob)
298 self.total_file_count = len(all_files)
299 for files in list_utils.shard(all_files, 500):
300 file_list = list(files)
301 results.append(self.read_files_from_list(file_list))
303 for result in smart_future.wait_any(results, callback=self.make_progress_graph):
304 result = result._resolve()
309 if not self.spec.quiet:
# Presumably clears the progress-bar line -- verify.
310 print(" " * 80 + "\n")
# Convert a raw features-file value string into a numeric feature.
# Visible cases:
#   - "False"/"None" and "True" map to fixed values assigned on lines not
#     visible here.
#   - Strings containing "." are scaled by 100 and rounded to an int.
# NOTE(review): called as TrainingBlueprint.normalize_feature(value), so it
# is presumably a @staticmethod (the decorator line is not visible).
314 def normalize_feature(value: str) -> Any:
315 if value in ("False", "None"):
317 elif value == "True":
319 elif isinstance(value, str) and "." in value:
320 ret = round(float(value) * 100.0)
# Randomly split (X, y) with sklearn's train_test_split. The result unpacks
# as X_train, X_test, y_train, y_test (see train()). random_state is drawn
# from random.randrange(0, 1000), so the split differs between runs.
326 def test_train_split(X, y) -> List:
327 logger.debug("Performing test/train split")
328 return train_test_split(
331 random_state=random.randrange(0, 1000),
# Min-max scale both feature matrices with one MinMaxScaler.
# Returns (scaler, X_train_scaled, X_test_scaled); the scaler is returned so
# it can be persisted with the model.
# NOTE(review): the scaler is presumably fit on X_train only (the fit line
# is not visible), which keeps test data out of the scaling -- verify.
335 def scale_data(X_train: np.ndarray, X_test: np.ndarray) -> Tuple[Any, np.ndarray, np.ndarray]:
336 logger.debug("Scaling data")
337 scaler = MinMaxScaler()
339 return (scaler, scaler.transform(X_train), scaler.transform(X_test))
341 # Note: children should implement. Consider using @parallelize.
# Train one model with `parameters` on the scaled training data. The return
# value must provide get_id(), which train() calls. It may be a
# smart_future.SmartFuture, which train() resolves.
343 def train_model(self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray) -> Any:
# NOTE(review): the start of this signature (presumably evaluate_model,
# called from train()) is on lines not visible in this extract.
349 X_train_scaled: np.ndarray,
351 X_test_scaled: np.ndarray,
353 ) -> Tuple[np.float64, np.float64]:
354 logger.debug("Evaluating the model")
# model.score() results are scaled to percentages.
355 training_score = model.score(X_train_scaled, y_train) * 100.0
356 test_score = model.score(X_test_scaled, y_test) * 100.0
358 "Model evaluation results: test_score=%.5f, train_score=%.5f",
# Returned as (training_score, test_score); train() unpacks in this order.
362 return (training_score, test_score)
364 def maybe_persist_scaler_and_model(
366 training_score: np.float64,
367 test_score: np.float64,
372 ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
373 if not self.spec.dry_run:
374 import datetime_utils
378 now: datetime.datetime = datetime_utils.now_pacific()
379 info = f"""Timestamp: {datetime_utils.datetime_to_string(now)}
380 Model params: {params}
381 Training examples: {num_examples}
382 Training set score: {training_score:.2f}%
383 Testing set score: {test_score:.2f}%"""
386 self.spec.persist_percentage_threshold is not None
387 and test_score > self.spec.persist_percentage_threshold
389 not self.spec.quiet and input_utils.yn_response("Write the model? [y,n]: ") == "y"
391 scaler_filename = f"{self.spec.basename}_scaler.sav"
392 with open(scaler_filename, "wb") as fb:
393 pickle.dump(scaler, fb)
394 msg = f"Wrote {scaler_filename}"
397 model_filename = f"{self.spec.basename}_model.sav"
398 with open(model_filename, "wb") as fb:
399 pickle.dump(model, fb)
400 msg = f"Wrote {model_filename}"
403 model_info_filename = f"{self.spec.basename}_model_info.txt"
404 with open(model_info_filename, "w") as f:
406 msg = f"Wrote {model_info_filename}:"
409 print(string_utils.indent(info, 2))
411 return (scaler_filename, model_filename, model_info_filename)
412 return (None, None, None)