3 """This is a blueprint for training sklearn ML models."""
5 from __future__ import annotations
14 from abc import ABC, abstractmethod
15 from dataclasses import dataclass
16 from types import SimpleNamespace
17 from typing import Any, List, Optional, Set, Tuple
20 from sklearn.model_selection import train_test_split # type:ignore
21 from sklearn.preprocessing import MinMaxScaler # type: ignore
26 import parallelize as par
27 from ansi import bold, reset
28 from decorator_utils import timed
# Module-scope setup: a module-level logger plus the command-line flags
# that populate_from_config() reads back out of config.config below.
# NOTE(review): several add_argument() calls are only partially visible
# in this excerpt; their action/default keywords live on elided lines.
30 logger = logging.getLogger(__name__)
32 parser = config.add_commandline_args(
33 f"ML Model Trainer ({__file__})",
34 "Arguments related to training an ML model",
# --ml_trainer_quiet (presumably): suppress interactive prompts.
39 help="Don't prompt the user for anything.",
# --ml_trainer_delete: delete bad feature files rather than just warning.
42 "--ml_trainer_delete",
44 help="Delete invalid/incomplete features files in addition to warning.",
# dry-run and persist-threshold are mutually exclusive: either report
# efficacy only, or persist when the test score clears the threshold.
46 group = parser.add_mutually_exclusive_group()
48 "--ml_trainer_dry_run",
50 help="Do not write a new model, just report efficacy.",
53 "--ml_trainer_persist_threshold",
54 type=argparse_utils.valid_percentage,
56 help="Persist the model if the test set score is >= this threshold.",
60 class InputSpec(SimpleNamespace):
61     """A collection of info needed to train the model provided by the
# Feature names to ignore when reading feature files (checked in
# read_files_from_list; may be None).
66     features_to_skip: Set[str]
# Delimiter separating key from value on each feature-file line.
67     key_value_delimiter: str
# One entry per candidate model; each is passed to train_model().
68     training_parameters: List
# When True, report efficacy but never write model/scaler files.
71     dry_run: Optional[bool]
# Auto-persist when test score >= this (see maybe_persist_scaler_and_model).
73     persist_percentage_threshold: Optional[float]
# When True, delete invalid/incomplete feature files instead of only warning.
74     delete_bad_inputs: Optional[bool]
77 def populate_from_config() -> InputSpec:
# Build an InputSpec from the commandline flags registered above.
# NOTE(review): the `return InputSpec(` line is elided in this excerpt;
# only the keyword arguments sourced from config.config are visible.
79 dry_run=config.config["ml_trainer_dry_run"],
80 quiet=config.config["ml_trainer_quiet"],
81 persist_percentage_threshold=config.config["ml_trainer_persist_threshold"],
82 delete_bad_inputs=config.config["ml_trainer_delete"],
# NOTE(review): the enclosing class header (presumably a @dataclass named
# OutputSpec, per train()'s return annotation) is on elided lines above.
88     """Info about the results of training returned to the caller."""
# Paths of persisted artifacts; all None when nothing was written
# (dry run, below-threshold score, or user declined).
90     model_filename: Optional[str] = None
91     model_info_filename: Optional[str] = None
92     scaler_filename: Optional[str] = None
# Best model's scores as percentages (0.0-100.0), see evaluate_model.
93     training_score: np.float64 = np.float64(0.0)
94     test_score: np.float64 = np.float64(0.0)
97 class TrainingBlueprint(ABC):
98     """The blueprint for doing the actual training."""
# NOTE(review): the `def __init__(self)` line is elided; these are
# instance attributes initialized in the constructor.
# Scaled feature matrices, filled in by train() via scale_data().
103         self.X_test_scaled = None
104         self.X_train_scaled = None
# Progress counters consumed by make_progress_graph().
105         self.file_done_count = 0
106         self.total_file_count = 0
# Top-level training driver: read feature files, split, scale, train one
# model per parameter set (in parallel, via smart_future), score each on
# a weighted blend of train/test accuracy, keep the best, and maybe
# persist it.  NOTE(review): many lines are elided in this excerpt.
109     def train(self, spec: InputSpec) -> OutputSpec:
115         X_, y_ = self.read_input_files()
116         num_examples = len(y_)
118         # Every example's features
121         # Every example's label
124         print("Doing random test/train split...")
125         X_train, X_test, self.y_train, self.y_test = TrainingBlueprint.test_train_split(
130         print("Scaling training data...")
131         scaler, self.X_train_scaled, self.X_test_scaled = TrainingBlueprint.scale_data(
136         print("Training model(s)...")
# Kick off one (possibly async, see @par.parallelize on subclasses'
# train_model) training job per parameter set; remember which params
# produced which model id so results can be reported later.
138         modelid_to_params = {}
139         for params in self.spec.training_parameters:
140             model = self.train_model(params, self.X_train_scaled, self.y_train)
142             modelid_to_params[model.get_id()] = str(params)
145         best_score: Optional[np.float64] = None
146         best_test_score: Optional[np.float64] = None
147         best_training_score: Optional[np.float64] = None
# Consume models as they finish rather than in submission order.
149         for model in smart_future.wait_any(models):
150             params = modelid_to_params[model.get_id()]
151             if isinstance(model, smart_future.SmartFuture):
152                 model = model._resolve()
153             if model is not None:
154                 training_score, test_score = TrainingBlueprint.evaluate_model(
# Weighted blend: the test-set score counts 20x more than the
# training-set score when picking the best model.
161                 score = (training_score + test_score * 20) / 21
162                 if not self.spec.quiet:
164                     f"{bold()}{params}{reset()}: "
165                     f"Training set score={training_score:.2f}%, "
166                     f"test set score={test_score:.2f}%",
169                 if best_score is None or score > best_score:
171                     best_test_score = test_score
172                     best_training_score = training_score
175                     if not self.spec.quiet:
176                         print(f"New best score {best_score:.2f}% with params {params}")
178         if not self.spec.quiet:
# Presumably shuts down the thread/process pools used above so the
# program can exit promptly -- TODO confirm against executors module.
179             executors.DefaultExecutors().shutdown()
180             msg = f"Done training; best test set score was: {best_test_score:.1f}%"
# These must have been set by the loop above; a failure here means no
# model trained successfully.
184         assert best_training_score is not None
185         assert best_test_score is not None
186         assert best_params is not None
191         ) = self.maybe_persist_scaler_and_model(
200             model_filename=model_filename,
201             model_info_filename=model_info_filename,
202             scaler_filename=scaler_filename,
203             training_score=best_training_score,
204             test_score=best_test_score,
# Parse a shard of feature files into parallel (features, labels) lists.
# Runs on a thread pool; read_input_files() shards the glob results and
# resolves the returned SmartFutures.  NOTE(review): lines elided here
# include the per-line loop header and the append/continue statements.
207     @par.parallelize(method=par.Method.THREAD)
208     def read_files_from_list(self, files: List[str]) -> Tuple[List, List]:
215         for filename in files:
217             with open(filename, "r") as f:
218                 lines = f.readlines()
220             # This example's features
224             # We expect lines in features files to be of the form:
# Malformed lines (wrong number of delimiter-separated fields,
# presumably) are logged and skipped rather than raising.
229                 (key, value) = line.split(self.spec.key_value_delimiter)
231                     logger.debug("WARNING: bad line in file %s '%s', skipped", filename, line)
235                 value = value.strip()
236                 if self.spec.features_to_skip is not None and key in self.spec.features_to_skip:
237                     logger.debug("Skipping feature %s", key)
# Coerce string values into numeric-ish features (see normalize_feature).
240                 value = TrainingBlueprint.normalize_feature(value)
# The line whose key matches spec.label is the example's label,
# not a feature.
242                 if key == self.spec.label:
248             # Make sure we saw a label and the requisite number of features.
249             if len(x) == self.spec.feature_count and wrote_label:
251                 self.file_done_count += 1
# Incomplete file: optionally delete it (spec.delete_bad_inputs),
# otherwise just warn and skip it.
256             if self.spec.delete_bad_inputs:
257                 msg = f"WARNING: (unknown): missing features or label; expected {self.spec.feature_count} but saw {len(x)}. DELETING."
262                 msg = f"WARNING: (unknown): missing features or label; expected {self.spec.feature_count} but saw {len(x)}. Skipping."
# Render a progress bar of files parsed so far; used as the wait_any()
# callback in read_input_files().  Import is deferred, presumably to
# keep text_utils off the critical import path -- TODO confirm.
267     def make_progress_graph(self) -> None:
268         if not self.spec.quiet:
269             from text_utils import progress_graph
271             progress_graph(self.file_done_count, self.total_file_count)
# Glob all feature files, shard them into chunks of 500, parse each
# shard on the thread pool (read_files_from_list), and merge results
# as futures resolve.  NOTE(review): the lines that accumulate the
# resolved (X, y) pairs and the return statement are elided here.
274     def read_input_files(self):
285         all_files = glob.glob(self.spec.file_glob)
286         self.total_file_count = len(all_files)
287         for files in list_utils.shard(all_files, 500):
288             file_list = list(files)
289             results.append(self.read_files_from_list(file_list))
# Resolve shard futures in completion order, updating the progress
# graph as each one lands.
291         for result in smart_future.wait_any(results, callback=self.make_progress_graph):
292             result = result._resolve()
# Blank out the progress-graph line before returning.
297         if not self.spec.quiet:
298             print(" " * 80 + "\n")
# Map a raw feature-file string to a numeric-ish value: "False"/"None"
# and "True" become fixed values (on elided lines), and decimal-looking
# strings are scaled by 100 and rounded to an int.
# NOTE(review): the return statements between branches are elided.
302     def normalize_feature(value: str) -> Any:
303         if value in ("False", "None"):
305         elif value == "True":
# The isinstance check is redundant given the str annotation, but kept
# byte-identical here.
307         elif isinstance(value, str) and "." in value:
308             ret = round(float(value) * 100.0)
# Thin wrapper over sklearn's train_test_split with a random seed drawn
# per call, so each training run gets a different split.
314     def test_train_split(X, y) -> List:
315         logger.debug("Performing test/train split")
316         return train_test_split(
319             random_state=random.randrange(0, 1000),
# Fit a MinMaxScaler and return it along with scaled train/test sets.
# NOTE(review): the scaler.fit(X_train) call is on an elided line --
# the transform() calls below require it.
323     def scale_data(X_train: np.ndarray, X_test: np.ndarray) -> Tuple[Any, np.ndarray, np.ndarray]:
324         logger.debug("Scaling data")
325         scaler = MinMaxScaler()
# The scaler is returned so it can be persisted next to the model and
# reapplied at inference time.
327         return (scaler, scaler.transform(X_train), scaler.transform(X_test))
329     # Note: children should implement. Consider using @parallelize.
# Abstract hook: fit one model for one entry of spec.training_parameters
# and return it (or a SmartFuture of it; train() resolves either).
331     def train_model(self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray) -> Any:
# NOTE(review): the `def evaluate_model(` line and the model/y params
# are elided above; this scores the model on both splits and returns
# the scores as percentages (model.score() * 100).
337         X_train_scaled: np.ndarray,
339         X_test_scaled: np.ndarray,
341     ) -> Tuple[np.float64, np.float64]:
342         logger.debug("Evaluating the model")
343         training_score = model.score(X_train_scaled, y_train) * 100.0
344         test_score = model.score(X_test_scaled, y_test) * 100.0
346             "Model evaluation results: test_score=%.5f, train_score=%.5f",
350         return (training_score, test_score)
# Decide whether to write the scaler, model, and an info sidecar to
# disk, and do so if warranted.  Returns the three filenames, or
# (None, None, None) when nothing was persisted.  Persistence happens
# when not a dry run AND (test score clears the threshold OR, when not
# quiet, the user answers "y" at the prompt) -- the elided lines
# presumably contain that if-expression's connective; TODO confirm.
352     def maybe_persist_scaler_and_model(
354         training_score: np.float64,
355         test_score: np.float64,
360     ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
361         if not self.spec.dry_run:
362             import datetime_utils
# Human-readable provenance blob written to the *_model_info.txt file.
366             now: datetime.datetime = datetime_utils.now_pacific()
367             info = f"""Timestamp: {datetime_utils.datetime_to_string(now)}
368 Model params: {params}
369 Training examples: {num_examples}
370 Training set score: {training_score:.2f}%
371 Testing set score: {test_score:.2f}%"""
374                 self.spec.persist_percentage_threshold is not None
375                 and test_score > self.spec.persist_percentage_threshold
377                 not self.spec.quiet and input_utils.yn_response("Write the model? [y,n]: ") == "y"
# Pickle the scaler and the model; write the info text alongside.
379             scaler_filename = f"{self.spec.basename}_scaler.sav"
380             with open(scaler_filename, "wb") as fb:
381                 pickle.dump(scaler, fb)
382             msg = f"Wrote {scaler_filename}"
385             model_filename = f"{self.spec.basename}_model.sav"
386             with open(model_filename, "wb") as fb:
387                 pickle.dump(model, fb)
388             msg = f"Wrote {model_filename}"
391             model_info_filename = f"{self.spec.basename}_model_info.txt"
392             with open(model_info_filename, "w") as f:
394             msg = f"Wrote {model_info_filename}:"
397             print(string_utils.indent(info, 2))
399             return (scaler_filename, model_filename, model_info_filename)
400 return (None, None, None)