from __future__ import annotations

from abc import ABC, abstractmethod
import datetime
import glob
import logging
import os
import pickle
import random
import sys
from types import SimpleNamespace
from typing import Any, List, NamedTuple, Optional, Set, Tuple

import numpy as np
from sklearn.model_selection import train_test_split  # type: ignore
from sklearn.preprocessing import MinMaxScaler  # type: ignore

from ansi import bold, reset
import argparse_utils
import config
import datetime_utils
import decorator_utils
import input_utils
import list_utils
import parallelize as par
import smart_future
import string_utils
import text_utils

logger = logging.getLogger(__file__)

parser = config.add_commandline_args(
    f"ML Model Trainer ({__file__})",
    "Arguments related to training an ML model",
)
parser.add_argument(
    "--ml_trainer_quiet",
    action="store_true",
    help="Don't prompt the user for anything.",
)
parser.add_argument(
    "--ml_trainer_delete",
    action="store_true",
    help="Delete invalid/incomplete features files in addition to warning.",
)
group = parser.add_mutually_exclusive_group()
group.add_argument(
    "--ml_trainer_dry_run",
    action="store_true",
    help="Do not write a new model, just report efficacy.",
)
group.add_argument(
    "--ml_trainer_predicate",
    type=argparse_utils.valid_percentage,
    help="Persist the model if the test set score exceeds this value.",
)


class InputSpec(SimpleNamespace):
    file_glob: str
    feature_count: int
    features_to_skip: Set[str]
    key_value_delimiter: str
    training_parameters: List
    label: str
    basename: str
    dry_run: Optional[bool]
    quiet: Optional[bool]
    persist_predicate: Optional[float]
    delete_bad_inputs: Optional[bool]

    @staticmethod
    def populate_from_config() -> InputSpec:
        return InputSpec(
            dry_run = config.config["ml_trainer_dry_run"],
            quiet = config.config["ml_trainer_quiet"],
            persist_predicate = config.config["ml_trainer_predicate"],
            delete_bad_inputs = config.config["ml_trainer_delete"],
        )
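
    # Sketch of how a caller might build a complete spec: start from the
    # command line flags, then fill in job-specific fields.  The glob,
    # delimiter, and feature names below are illustrative only:
    #
    #   spec = InputSpec.populate_from_config()
    #   spec.file_glob = "/data/features/*.txt"
    #   spec.key_value_delimiter = ":"
    #   spec.feature_count = 4
    #   spec.features_to_skip = {"timestamp"}
    #   spec.label = "label"
    #   spec.basename = "mymodel"
    #   spec.training_parameters = [{"n_neighbors": k} for k in (3, 5, 7)]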


class OutputSpec(NamedTuple):
    model_filename: Optional[str]
    model_info_filename: Optional[str]
    scaler_filename: Optional[str]
    training_score: float
    test_score: float


class TrainingBlueprint(ABC):
    def __init__(self):
        self.y_train = None
        self.y_test = None
        self.X_test_scaled = None
        self.X_train_scaled = None
        self.file_done_count = 0
        self.total_file_count = 0
        self.spec = None

    def train(self, spec: InputSpec) -> OutputSpec:
        self.spec = spec

        X_, y_ = self.read_input_files()
        num_examples = len(y_)

        # Every example's features
        X = np.array(X_)

        # Every example's label
        y = np.array(y_)

        print("Doing random test/train split...")
        X_train, X_test, self.y_train, self.y_test = self.test_train_split(X, y)

        print("Scaling training data...")
        scaler, self.X_train_scaled, self.X_test_scaled = self.scale_data(
            X_train, X_test
        )

        print("Training model(s)...")
        models = []
        modelid_to_params = {}
        for params in self.spec.training_parameters:
            model = self.train_model(params, self.X_train_scaled, self.y_train)
            models.append(model)
            modelid_to_params[model.get_id()] = str(params)

        best_model = None
        best_params = None
        best_score = None
        best_test_score = None
        best_training_score = None
        for model in smart_future.wait_many(models):
            params = modelid_to_params[model.get_id()]
            if isinstance(model, smart_future.SmartFuture):
                model = model._resolve()
            training_score, test_score = self.evaluate_model(
                model,
                self.X_train_scaled, self.y_train,
                self.X_test_scaled, self.y_test,
            )

            # Blend the scores, weighting the test set score much more
            # heavily than the (easily overfit) training set score.
            score = (training_score + test_score * 20) / 21
            if not self.spec.quiet:
                print(
                    f"{bold()}{params}{reset()}: "
                    f"Training set score={training_score:.2f}%, "
                    f"test set score={test_score:.2f}%",
                    file=sys.stderr,
                )
            if best_score is None or score > best_score:
                best_model = model
                best_params = params
                best_score = score
                best_test_score = test_score
                best_training_score = training_score
                if not self.spec.quiet:
                    print(
                        f"New best score {best_score:.2f}% with params {params}"
                    )

        if not self.spec.quiet:
            msg = f"Done training; best test set score was: {best_test_score:.1f}%"
            print(msg)
            logger.info(msg)
        scaler_filename, model_filename, model_info_filename = (
            self.maybe_persist_scaler_and_model(
                best_training_score, best_test_score, best_params,
                num_examples, scaler, best_model,
            )
        )
        return OutputSpec(
            model_filename = model_filename,
            model_info_filename = model_info_filename,
            scaler_filename = scaler_filename,
            training_score = best_training_score,
            test_score = best_test_score,
        )

    @par.parallelize(method=par.Method.THREAD)
    def read_files_from_list(
            self,
            files: List[str],
            n: int,
    ) -> Tuple[List, List]:
        # Features (X) and labels (y) for every valid example seen.
        X = []
        y = []
        for filename in files:
            wrote_label = False
            with open(filename, "r") as f:
                lines = f.readlines()

            # This example's features
            x = []
            for line in lines:

                # We expect lines in features files to be of the form:
                #
                # key: value
                line = line.strip()
                try:
                    (key, value) = line.split(self.spec.key_value_delimiter)
                except Exception:
                    print(f"WARNING: bad line '{line}', skipped", file=sys.stderr)
                    continue

                key = key.strip()
                value = value.strip()
                if (self.spec.features_to_skip is not None
                        and key in self.spec.features_to_skip):
                    logger.debug(f"Skipping feature {key}")
                    continue

                value = self.normalize_feature(value)
                if key == self.spec.label:
                    y.append(value)
                    wrote_label = True
                else:
                    x.append(value)

            # Make sure we saw a label and the requisite number of features.
            if len(x) == self.spec.feature_count and wrote_label:
                X.append(x)
                self.file_done_count += 1
            else:
                # Discard the label too so X and y stay aligned.
                if wrote_label:
                    y.pop()
                if self.spec.delete_bad_inputs:
                    msg = f"WARNING: {filename}: missing features or label. DELETING."
                    print(msg, file=sys.stderr)
                    logger.warning(msg)
                    os.remove(filename)
                else:
                    msg = f"WARNING: {filename}: missing features or label. Skipped."
                    print(msg, file=sys.stderr)
                    logger.warning(msg)
        return (X, y)
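
    # For illustration, with key_value_delimiter=":", label="label", and
    # feature_count=3, a well-formed features file (hypothetical content)
    # looks like:
    #
    #   temperature: 0.72
    #   humidity: 0.33
    #   is_weekend: True
    #   label: 1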

    def make_progress_graph(self) -> None:
        if not self.spec.quiet:
            text_utils.progress_graph(self.file_done_count,
                                      self.total_file_count)

    @decorator_utils.timed
    def read_input_files(self):
        # Fan the files out to thread workers, one shard per worker, and
        # merge their (features, labels) results as they complete.
        X = []
        y = []
        results = []
        all_files = glob.glob(self.spec.file_glob)
        self.total_file_count = len(all_files)
        for n, files in enumerate(list_utils.shard(all_files, 500)):
            file_list = list(files)
            results.append(self.read_files_from_list(file_list, n))

        for result in smart_future.wait_many(results, callback=self.make_progress_graph):
            result = result._resolve()
            (tmp_x, tmp_y) = result
            X.extend(tmp_x)
            y.extend(tmp_y)
        if not self.spec.quiet:
            print(" " * 80 + "\n")  # Scroll past the progress graph.
        return (X, y)

    def normalize_feature(self, value: str) -> Any:
        if value in ("False", "None"):
            ret = 0
        elif value == "True":
            ret = 255
        elif isinstance(value, str) and "." in value:
            # Fractional values are scaled up to integers, e.g. 0.25 -> 25.
            ret = round(float(value) * 100.0)
        else:
            ret = int(value)
        return ret
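
    # For instance, under the mapping above: normalize_feature("False")
    # == 0, normalize_feature("True") == 255, normalize_feature("0.25")
    # == 25, and normalize_feature("7") == 7.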

    def test_train_split(self, X, y) -> List:
        logger.debug("Performing test/train split")
        return train_test_split(
            X,
            y,
            random_state=random.randrange(0, 1000),
        )

    def scale_data(self,
                   X_train: np.ndarray,
                   X_test: np.ndarray) -> Tuple[Any, np.ndarray, np.ndarray]:
        logger.debug("Scaling data")
        scaler = MinMaxScaler()
        scaler.fit(X_train)
        return (scaler, scaler.transform(X_train), scaler.transform(X_test))
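
    # MinMaxScaler learns each feature column's min/max on the training
    # set and applies that same fit to both splits, so the two remain
    # comparable; e.g. a column observed spanning 0..255 maps 128 to
    # roughly 0.502 in the training *and* test data.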

    # Note: children should implement.  Consider using @parallelize.
    @abstractmethod
    def train_model(self,
                    parameters,
                    X_train_scaled: np.ndarray,
                    y_train: np.ndarray) -> Any:
        pass
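
    # A minimal sketch of a concrete child, assuming scikit-learn's
    # KNeighborsClassifier and training_parameters as a list of kwargs
    # dicts (both illustrative, not part of this module).  Decorating
    # train_model with @par.parallelize makes it return a SmartFuture,
    # which provides the get_id() that train() calls on each result:
    #
    #   class KNNTrainingBlueprint(TrainingBlueprint):
    #       @par.parallelize(method=par.Method.THREAD)
    #       def train_model(self, parameters, X_train_scaled, y_train):
    #           from sklearn.neighbors import KNeighborsClassifier
    #           model = KNeighborsClassifier(**parameters)
    #           model.fit(X_train_scaled, y_train)
    #           return model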

    def evaluate_model(self,
                       model: Any,
                       X_train_scaled: np.ndarray,
                       y_train: np.ndarray,
                       X_test_scaled: np.ndarray,
                       y_test: np.ndarray) -> Tuple[np.float64, np.float64]:
        logger.debug("Evaluating the model")
        training_score = model.score(X_train_scaled, y_train) * 100.0
        test_score = model.score(X_test_scaled, y_test) * 100.0
        logger.info(
            f"Model evaluation results: test_score={test_score:.5f}, "
            f"train_score={training_score:.5f}"
        )
        return (training_score, test_score)
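
    # Note: for scikit-learn estimators, score() returns mean accuracy
    # for classifiers and R^2 for regressors; either way it is multiplied
    # by 100.0 above before being reported as a percentage.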

    def maybe_persist_scaler_and_model(
            self,
            training_score: np.float64,
            test_score: np.float64,
            params: str,
            num_examples: int,
            scaler: Any,
            model: Any) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        if not self.spec.dry_run:
            if (
                (self.spec.persist_predicate is not None and
                 test_score > self.spec.persist_predicate)
                or
                (not self.spec.quiet
                 and input_utils.yn_response("Write the model? [y,n]: ") == "y")
            ):
                scaler_filename = f"{self.spec.basename}_scaler.sav"
                with open(scaler_filename, "wb") as f:
                    pickle.dump(scaler, f)
                msg = f"Wrote {scaler_filename}"
                print(msg)
                logger.info(msg)

                model_filename = f"{self.spec.basename}_model.sav"
                with open(model_filename, "wb") as f:
                    pickle.dump(model, f)
                msg = f"Wrote {model_filename}"
                print(msg)
                logger.info(msg)

                model_info_filename = f"{self.spec.basename}_model_info.txt"
                now: datetime.datetime = datetime_utils.now_pst()
                info = f"""Timestamp: {datetime_utils.datetime_to_string(now)}
Model params: {params}
Training examples: {num_examples}
Training set score: {training_score:.2f}%
Testing set score: {test_score:.2f}%"""
                with open(model_info_filename, "w") as f:
                    f.write(info)
                msg = f"Wrote {model_info_filename}:"
                print(msg)
                logger.info(msg)
                print(string_utils.indent(info, 2))
                return (scaler_filename, model_filename, model_info_filename)
        return (None, None, None)
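
# A later consumer might reload the persisted artifacts like this
# (sketch; the basename "mymodel" is illustrative):
#
#   with open("mymodel_scaler.sav", "rb") as f:
#       scaler = pickle.load(f)
#   with open("mymodel_model.sav", "rb") as f:
#       model = pickle.load(f)
#   prediction = model.predict(scaler.transform([raw_features]))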