[python_utils.git] / ml_model_trainer.py
#!/usr/bin/env python3

from __future__ import annotations

from abc import ABC, abstractmethod
import datetime
import glob
import logging
import os
import pickle
import random
import sys
from types import SimpleNamespace
from typing import Any, List, NamedTuple, Optional, Set, Tuple

import numpy as np
from sklearn.model_selection import train_test_split  # type: ignore
from sklearn.preprocessing import MinMaxScaler  # type: ignore

from ansi import bold, reset
import argparse_utils
import config
import datetime_utils
import decorator_utils
import input_utils
import list_utils
import parallelize as par
import smart_future
import string_utils
import text_utils

logger = logging.getLogger(__file__)

parser = config.add_commandline_args(
    f"ML Model Trainer ({__file__})",
    "Arguments related to training an ML model"
)
parser.add_argument(
    "--ml_trainer_quiet",
    action="store_true",
    help="Don't prompt the user for anything."
)
parser.add_argument(
    "--ml_trainer_delete",
    action="store_true",
    help="Delete invalid/incomplete features files in addition to warning."
)
group = parser.add_mutually_exclusive_group()
group.add_argument(
    "--ml_trainer_dry_run",
    action="store_true",
    help="Do not write a new model, just report efficacy.",
)
group.add_argument(
    "--ml_trainer_persist_threshold",
    type=argparse_utils.valid_percentage,
    metavar='0..100',
    help="Persist the model if the test set score is >= this threshold.",
)


class InputSpec(SimpleNamespace):
    file_glob: str
    feature_count: int
    features_to_skip: Set[str]
    key_value_delimiter: str
    training_parameters: List
    label: str
    basename: str
    dry_run: Optional[bool]
    quiet: Optional[bool]
    persist_percentage_threshold: Optional[float]
    delete_bad_inputs: Optional[bool]

    @staticmethod
    def populate_from_config() -> InputSpec:
        return InputSpec(
            dry_run=config.config["ml_trainer_dry_run"],
            quiet=config.config["ml_trainer_quiet"],
            persist_percentage_threshold=config.config["ml_trainer_persist_threshold"],
            delete_bad_inputs=config.config["ml_trainer_delete"],
        )


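# Illustrative sketch (not part of the original module): one way a caller
# might build a fully populated InputSpec.  populate_from_config() only fills
# the fields backed by the --ml_trainer_* flags above; the glob, delimiter,
# label, and training-parameter values below are hypothetical placeholders.
def _example_input_spec() -> InputSpec:
    spec = InputSpec.populate_from_config()
    spec.file_glob = "/data/features/*.txt"
    spec.feature_count = 14
    spec.features_to_skip = {"filename", "timestamp"}
    spec.key_value_delimiter = ":"
    spec.training_parameters = [{"n_neighbors": k} for k in (3, 5, 9)]
    spec.label = "label"
    spec.basename = "example_model"
    return spec

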
class OutputSpec(NamedTuple):
    model_filename: Optional[str]
    model_info_filename: Optional[str]
    scaler_filename: Optional[str]
    training_score: float
    test_score: float


class TrainingBlueprint(ABC):
    def __init__(self):
        self.y_train = None
        self.y_test = None
        self.X_test_scaled = None
        self.X_train_scaled = None
        self.file_done_count = 0
        self.total_file_count = 0
        self.spec = None

    def train(self, spec: InputSpec) -> OutputSpec:
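        """Top-level training pipeline: read the feature files, split and
        scale the data, train one candidate model per entry in
        spec.training_parameters, evaluate each, and (optionally) persist
        the best-scoring scaler/model pair."""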
        random.seed()
        self.spec = spec

        X_, y_ = self.read_input_files()
        num_examples = len(y_)

        # Every example's features
        X = np.array(X_)

        # Every example's label
        y = np.array(y_)

        print("Doing random test/train split...")
        X_train, X_test, self.y_train, self.y_test = self.test_train_split(
            X,
            y,
        )

        print("Scaling training data...")
        scaler, self.X_train_scaled, self.X_test_scaled = self.scale_data(
            X_train,
            X_test,
        )

        print("Training model(s)...")
        models = []
        modelid_to_params = {}
        for params in self.spec.training_parameters:
            model = self.train_model(
                params,
                self.X_train_scaled,
                self.y_train
            )
            models.append(model)
            modelid_to_params[model.get_id()] = str(params)

        best_model = None
        best_score = None
        best_test_score = None
        best_training_score = None
        best_params = None
        for model in smart_future.wait_many(models):
            params = modelid_to_params[model.get_id()]
            if isinstance(model, smart_future.SmartFuture):
                model = model._resolve()
            training_score, test_score = self.evaluate_model(
                model,
                self.X_train_scaled,
                self.y_train,
                self.X_test_scaled,
                self.y_test,
            )
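            # Blend the two scores into one ranking metric, weighting the
            # test set score 20x more heavily than the training set score.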
            score = (training_score + test_score * 20) / 21
            if not self.spec.quiet:
                print(
                    f"{bold()}{params}{reset()}: "
                    f"Training set score={training_score:.2f}%, "
                    f"test set score={test_score:.2f}%",
                    file=sys.stderr,
                )
            if best_score is None or score > best_score:
                best_score = score
                best_test_score = test_score
                best_training_score = training_score
                best_model = model
                best_params = params
                if not self.spec.quiet:
                    print(
                        f"New best score {best_score:.2f}% with params {params}"
                    )

        if not self.spec.quiet:
            msg = f"Done training; best test set score was: {best_test_score:.1f}%"
            print(msg)
            logger.info(msg)
        scaler_filename, model_filename, model_info_filename = (
            self.maybe_persist_scaler_and_model(
                best_training_score,
                best_test_score,
                best_params,
                num_examples,
                scaler,
                best_model,
            )
        )
        return OutputSpec(
            model_filename=model_filename,
            model_info_filename=model_info_filename,
            scaler_filename=scaler_filename,
            training_score=best_training_score,
            test_score=best_test_score,
        )

    @par.parallelize(method=par.Method.THREAD)
    def read_files_from_list(
            self,
            files: List[str],
            n: int
    ) -> Tuple[List, List]:
        # All features
        X = []

        # The label
        y = []

        for filename in files:
            wrote_label = False
            with open(filename, "r") as f:
                lines = f.readlines()

            # This example's features
            x = []
            for line in lines:

                # We expect lines in features files to be of the form:
                #
                # key: value
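                #
                # where "key" names a feature (or the label) and "value" is
                # its raw value, e.g. a hypothetical "num_wheels: 4".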
                line = line.strip()
                try:
                    (key, value) = line.split(self.spec.key_value_delimiter)
                except ValueError:
                    logger.debug(f"WARNING: bad line in file {filename} '{line}', skipped")
                    continue

                key = key.strip()
                value = value.strip()
                if (self.spec.features_to_skip is not None
                        and key in self.spec.features_to_skip):
                    logger.debug(f"Skipping feature {key}")
                    continue

                value = self.normalize_feature(value)

                if key == self.spec.label:
                    y.append(value)
                    wrote_label = True
                else:
                    x.append(value)

            # Make sure we saw a label and the requisite number of features.
            if len(x) == self.spec.feature_count and wrote_label:
                X.append(x)
                self.file_done_count += 1
            else:
                if wrote_label:
                    y.pop()

                if self.spec.delete_bad_inputs:
                    msg = f"WARNING: {filename}: missing features or label.  DELETING."
                    print(msg, file=sys.stderr)
                    logger.warning(msg)
                    os.remove(filename)
                else:
                    msg = f"WARNING: {filename}: missing features or label.  Skipped."
                    print(msg, file=sys.stderr)
                    logger.warning(msg)
        return (X, y)

    def make_progress_graph(self) -> None:
        if not self.spec.quiet:
            text_utils.progress_graph(self.file_done_count,
                                      self.total_file_count)

    @decorator_utils.timed
    def read_input_files(self):
        # All features
        X = []

        # The label
        y = []

        results = []
        all_files = glob.glob(self.spec.file_glob)
        self.total_file_count = len(all_files)
        for n, files in enumerate(list_utils.shard(all_files, 500)):
            file_list = list(files)
            results.append(self.read_files_from_list(file_list, n))

        for result in smart_future.wait_many(results, callback=self.make_progress_graph):
            result = result._resolve()
            for z in result[0]:
                X.append(z)
            for z in result[1]:
                y.append(z)
        if not self.spec.quiet:
            print(" " * 80 + "\n")
        return (X, y)

    def normalize_feature(self, value: str) -> Any:
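        """Map a raw string value onto a small integer feature: booleans and
        None collapse to 0/255, decimal strings are scaled by 100 and rounded,
        and everything else is parsed as an int."""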
        if value in ("False", "None"):
            ret = 0
        elif value == "True":
            ret = 255
        elif isinstance(value, str) and "." in value:
            ret = round(float(value) * 100.0)
        else:
            ret = int(value)
        return ret

    def test_train_split(self, X, y) -> List:
        logger.debug("Performing test/train split")
        return train_test_split(
            X,
            y,
            random_state=random.randrange(0, 1000),
        )

    def scale_data(self,
                   X_train: np.ndarray,
                   X_test: np.ndarray) -> Tuple[Any, np.ndarray, np.ndarray]:
        logger.debug("Scaling data")
        scaler = MinMaxScaler()
        scaler.fit(X_train)
        return (scaler, scaler.transform(X_train), scaler.transform(X_test))

    # Note: children should implement.  Consider using @parallelize.
    @abstractmethod
    def train_model(self,
                    parameters,
                    X_train_scaled: np.ndarray,
                    y_train: np.ndarray) -> Any:
        pass

    def evaluate_model(
            self,
            model: Any,
            X_train_scaled: np.ndarray,
            y_train: np.ndarray,
            X_test_scaled: np.ndarray,
            y_test: np.ndarray) -> Tuple[np.float64, np.float64]:
        logger.debug("Evaluating the model")
        training_score = model.score(X_train_scaled, y_train) * 100.0
        test_score = model.score(X_test_scaled, y_test) * 100.0
        logger.info(
            f"Model evaluation results: test_score={test_score:.5f}, "
            f"train_score={training_score:.5f}"
        )
        return (training_score, test_score)

    def maybe_persist_scaler_and_model(
            self,
            training_score: np.float64,
            test_score: np.float64,
            params: str,
            num_examples: int,
            scaler: Any,
            model: Any) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        if not self.spec.dry_run:
            if (
                    (self.spec.persist_percentage_threshold is not None and
                     test_score > self.spec.persist_percentage_threshold)
                    or
                    (not self.spec.quiet
                     and input_utils.yn_response("Write the model? [y,n]: ") == "y")
            ):
                scaler_filename = f"{self.spec.basename}_scaler.sav"
                with open(scaler_filename, "wb") as f:
                    pickle.dump(scaler, f)
                msg = f"Wrote {scaler_filename}"
                print(msg)
                logger.info(msg)
                model_filename = f"{self.spec.basename}_model.sav"
                with open(model_filename, "wb") as f:
                    pickle.dump(model, f)
                msg = f"Wrote {model_filename}"
                print(msg)
                logger.info(msg)
                model_info_filename = f"{self.spec.basename}_model_info.txt"
                now: datetime.datetime = datetime_utils.now_pst()
                info = f"""Timestamp: {datetime_utils.datetime_to_string(now)}
Model params: {params}
Training examples: {num_examples}
Training set score: {training_score:.2f}%
Testing set score: {test_score:.2f}%"""
                with open(model_info_filename, "w") as f:
                    f.write(info)
                msg = f"Wrote {model_info_filename}:"
                print(msg)
                logger.info(msg)
                print(string_utils.indent(info, 2))
                logger.info(info)
                return (scaler_filename, model_filename, model_info_filename)
        return (None, None, None)
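

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal concrete
# trainer built on TrainingBlueprint.  The KNeighborsClassifier model and its
# hyperparameters are assumptions chosen only to show the shape of a child
# class; a real trainer would substitute its own model and parameters.
# ---------------------------------------------------------------------------
class ExampleKNNTrainer(TrainingBlueprint):
    @par.parallelize(method=par.Method.THREAD)
    def train_model(self,
                    parameters,
                    X_train_scaled: np.ndarray,
                    y_train: np.ndarray) -> Any:
        # Imported here so the sketch stays self-contained.
        from sklearn.neighbors import KNeighborsClassifier  # type: ignore

        model = KNeighborsClassifier(**parameters)
        model.fit(X_train_scaled, y_train)
        return model


# Example usage (once the program's config/argument parsing has happened and
# an InputSpec such as the hypothetical _example_input_spec() above exists):
#
#     output_spec = ExampleKNNTrainer().train(_example_input_spec())
#     print(output_spec)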