Initial revision
[python_utils.git] / ml_model_trainer.py
1 #!/usr/bin/env python3
2
3 from __future__ import annotations
4
5 from abc import ABC, abstractmethod
6 import datetime
7 import glob
8 import logging
9 import os
10 import pickle
11 import random
12 import sys
13 from types import SimpleNamespace
14 from typing import Any, List, NamedTuple, Optional, Set, Tuple
15
16 import numpy as np
17 from sklearn.model_selection import train_test_split  # type:ignore
18 from sklearn.preprocessing import MinMaxScaler  # type: ignore
19
20 from ansi import bold, reset
21 import argparse_utils
22 import config
23 import datetime_utils
24 import decorator_utils
25 import input_utils
26 import list_utils
27 import parallelize as par
28 import smart_future
29 import string_utils
30 import text_utils
31
logger = logging.getLogger(__file__)

# Flags for this module are registered on a shared parser obtained from the
# config module, under their own titled argument section.
parser = config.add_commandline_args(
    f"ML Model Trainer ({__file__})",
    "Arguments related to training an ML model",
)
parser.add_argument(
    "--ml_trainer_quiet",
    help="Don't prompt the user for anything.",
    action="store_true",
)
parser.add_argument(
    "--ml_trainer_delete",
    help="Delete invalid/incomplete features files in addition to warning.",
    action="store_true",
)

# A dry run never writes a model, while the predicate decides when to write
# one, so argparse is told to reject the two flags when given together.
group = parser.add_mutually_exclusive_group()
group.add_argument(
    "--ml_trainer_dry_run",
    help="Do not write a new model, just report efficacy.",
    action="store_true",
)
group.add_argument(
    "--ml_trainer_predicate",
    help="Persist the model if the test set score is >= this predicate.",
    metavar='0..100',
    type=argparse_utils.valid_percentage,
)
60
61
class InputSpec(SimpleNamespace):
    """Describes where training data lives and how a training run behaves.

    The fields below are annotations only.  Because this is a
    SimpleNamespace, an instance carries exactly the attributes it was
    constructed with (or that were assigned afterwards); nothing is
    defaulted.
    """

    # Glob pattern used to find the features files to read.
    file_glob: str
    # Number of non-label features a file must yield to count as an example.
    feature_count: int
    # Feature keys to ignore while reading files (may be None).
    features_to_skip: Set[str]
    # Separator between key and value on each features-file line.
    key_value_delimiter: str
    # One entry per model to train; each entry is handed to train_model.
    training_parameters: List
    # The key whose value is the example's label rather than a feature.
    label: str
    # Filename prefix for the persisted scaler/model/model-info files.
    basename: str
    # When true, never write a model; only report scores.
    dry_run: Optional[bool]
    # When true, suppress progress/score output and never prompt.
    quiet: Optional[bool]
    # Test-set score threshold for persisting the model without prompting.
    persist_predicate: Optional[float]
    # When true, delete features files that fail validation.
    delete_bad_inputs: Optional[bool]

    @staticmethod
    def populate_from_config() -> InputSpec:
        """Return a partial InputSpec built from the ml_trainer_* flags.

        Only dry_run, quiet, persist_predicate and delete_bad_inputs are
        set; the caller is responsible for assigning the remaining fields
        before training.
        """
        flags = config.config
        return InputSpec(
            dry_run=flags["ml_trainer_dry_run"],
            quiet=flags["ml_trainer_quiet"],
            persist_predicate=flags["ml_trainer_predicate"],
            delete_bad_inputs=flags["ml_trainer_delete"],
        )
83
84
class OutputSpec(NamedTuple):
    """Result of a training run: files written (if any) and the best scores.

    The three filename fields are None when nothing was persisted.
    """

    # Pickled model file, or None if the model was not written.
    model_filename: Optional[str]
    # Human-readable model summary file, or None if not written.
    model_info_filename: Optional[str]
    # Pickled scaler file, or None if not written.
    scaler_filename: Optional[str]
    # Best model's score on the training split, as a percentage.
    training_score: float
    # Best model's score on the held-out test split, as a percentage.
    test_score: float
91
92
class TrainingBlueprint(ABC):
    """Template for training models from a directory of features files.

    Subclasses implement :meth:`train_model`.  Everything else is provided:
    reading ``key<delimiter>value`` features files, the random test/train
    split, min/max scaling, scoring each trained model, choosing the best,
    and optionally pickling it.  Call :meth:`train` with a populated
    :class:`InputSpec`.
    """

    def __init__(self):
        # Split/scaled data; populated by train().
        self.y_train = None
        self.y_test = None
        self.X_test_scaled = None
        self.X_train_scaled = None
        # Progress counters displayed by make_progress_graph while reading.
        self.file_done_count = 0
        self.total_file_count = 0
        # The InputSpec passed to the most recent train() call.
        self.spec: Optional[InputSpec] = None

    def train(self, spec: InputSpec) -> OutputSpec:
        """Train one model per parameter set in *spec* and keep the best.

        Reads every features file matching ``spec.file_glob``, splits and
        scales the data, trains a model for each entry of
        ``spec.training_parameters``, and selects the model with the highest
        blended score (test-set score weighted 20:1 over training score).
        The winner may then be persisted via
        :meth:`maybe_persist_scaler_and_model`.

        Args:
            spec: describes the inputs and run options.

        Returns:
            An OutputSpec with the persisted filenames (or None) and the
            best model's training/test scores.

        Raises:
            ValueError: if ``spec.training_parameters`` is empty or no usable
                examples were read.
        """
        random.seed()
        self.spec = spec

        # Fail before the potentially slow file reads when there is nothing
        # to train; otherwise the best scores would stay None and crash later.
        if not self.spec.training_parameters:
            raise ValueError(
                "InputSpec.training_parameters is empty; nothing to train."
            )

        X_, y_ = self.read_input_files()
        num_examples = len(y_)
        if num_examples == 0:
            raise ValueError(
                f"No usable training examples found matching {self.spec.file_glob}"
            )

        # Every example's features, and every example's label.
        X = np.array(X_)
        y = np.array(y_)

        print("Doing random test/train split...")
        X_train, X_test, self.y_train, self.y_test = self.test_train_split(
            X,
            y,
        )

        print("Scaling training data...")
        scaler, self.X_train_scaled, self.X_test_scaled = self.scale_data(
            X_train,
            X_test,
        )

        print("Training model(s)...")
        # train_model may hand back SmartFutures (see its note), so start
        # every model first and collect the results as they complete below.
        models = []
        modelid_to_params = {}
        for params in self.spec.training_parameters:
            model = self.train_model(
                params,
                self.X_train_scaled,
                self.y_train
            )
            models.append(model)
            modelid_to_params[model.get_id()] = str(params)

        best_model = None
        best_score = None
        best_test_score = None
        best_training_score = None
        best_params = None
        for model in smart_future.wait_many(models):
            params = modelid_to_params[model.get_id()]
            if isinstance(model, smart_future.SmartFuture):
                model = model._resolve()
            training_score, test_score = self.evaluate_model(
                model,
                self.X_train_scaled,
                self.y_train,
                self.X_test_scaled,
                self.y_test,
            )
            # Blend the two scores, weighting the held-out test set 20:1.
            score = (training_score + test_score * 20) / 21
            if not self.spec.quiet:
                print(
                    f"{bold()}{params}{reset()}: "
                    f"Training set score={training_score:.2f}%, "
                    f"test set score={test_score:.2f}%",
                    file=sys.stderr,
                )
            if best_score is None or score > best_score:
                best_score = score
                best_test_score = test_score
                best_training_score = training_score
                best_model = model
                best_params = params
                if not self.spec.quiet:
                    print(
                        f"New best score {best_score:.2f}% with params {params}"
                    )

        if not self.spec.quiet:
            msg = f"Done training; best test set score was: {best_test_score:.1f}%"
            print(msg)
            logger.info(msg)
        scaler_filename, model_filename, model_info_filename = (
            self.maybe_persist_scaler_and_model(
                best_training_score,
                best_test_score,
                best_params,
                num_examples,
                scaler,
                best_model,
            )
        )
        return OutputSpec(
            model_filename=model_filename,
            model_info_filename=model_info_filename,
            scaler_filename=scaler_filename,
            training_score=best_training_score,
            test_score=best_test_score,
        )

    @par.parallelize(method=par.Method.THREAD)
    def read_files_from_list(
            self,
            files: List[str],
            n: int
    ) -> Tuple[List, List]:
        """Parse one shard of features files into parallel feature/label lists.

        Each line of a file is expected to be ``key<delimiter>value``.  A
        file contributes an example only when it yields exactly
        ``spec.feature_count`` features and a label; otherwise it is reported
        and, if ``spec.delete_bad_inputs`` is set, deleted.

        Args:
            files: the features files in this shard.
            n: the shard's index (not used by this body).

        Returns:
            ``(X, y)`` where ``X[i]`` is the feature list labelled ``y[i]``.
        """
        # All features, and the label for each row of X (kept index-aligned).
        X = []
        y = []

        for filename in files:
            with open(filename, "r") as f:
                lines = f.readlines()

            # This example's features and label.
            x = []
            label = None
            wrote_label = False
            for line in lines:
                # We expect lines in features files to be of the form:
                #
                # key: value
                line = line.strip()
                if not line:
                    # Blank lines carry no data; don't report them as bad.
                    continue
                try:
                    (key, value) = line.split(self.spec.key_value_delimiter)
                except ValueError as e:
                    logger.exception(e)
                    print(f"WARNING: bad line '{line}', skipped")
                    continue

                key = key.strip()
                value = value.strip()
                if (self.spec.features_to_skip is not None
                        and key in self.spec.features_to_skip):
                    logger.debug(f"Skipping feature {key}")
                    continue

                value = self.normalize_feature(value)

                if key == self.spec.label:
                    # Hold the label until the file is validated so X and y
                    # can never get out of step, even if it is repeated.
                    if wrote_label:
                        logger.warning(
                            f"{filename}: duplicate label line; using the last one"
                        )
                    label = value
                    wrote_label = True
                else:
                    x.append(value)

            # Count every processed file (valid or not) so the progress
            # graph can reach total_file_count.
            # NOTE(review): this runs on several threads via @parallelize;
            # the counter only drives the progress display.
            self.file_done_count += 1

            # Make sure we saw a label and the requisite number of features.
            if len(x) == self.spec.feature_count and wrote_label:
                X.append(x)
                y.append(label)
            elif self.spec.delete_bad_inputs:
                msg = f"WARNING: {filename}: missing features or label.  DELETING."
                print(msg, file=sys.stderr)
                logger.warning(msg)
                os.remove(filename)
            else:
                msg = f"WARNING: {filename}: missing features or label.  Skipped."
                print(msg, file=sys.stderr)
                logger.warning(msg)
        return (X, y)

    def make_progress_graph(self) -> None:
        """Draw file-reading progress unless the spec asks for quiet."""
        if not self.spec.quiet:
            text_utils.progress_graph(self.file_done_count,
                                      self.total_file_count)

    @decorator_utils.timed
    def read_input_files(self) -> Tuple[List, List]:
        """Read all files matching ``spec.file_glob`` in shards of 500.

        Each shard is handed to :meth:`read_files_from_list`; results are
        gathered as they finish, redrawing the progress graph along the way.

        Returns:
            ``(X, y)``: every valid example's features and matching labels.
        """
        X = []
        y = []

        results = []
        all_files = glob.glob(self.spec.file_glob)
        self.total_file_count = len(all_files)
        for n, files in enumerate(list_utils.shard(all_files, 500)):
            results.append(self.read_files_from_list(list(files), n))

        for result in smart_future.wait_many(results, callback=self.make_progress_graph):
            shard_X, shard_y = result._resolve()
            X.extend(shard_X)
            y.extend(shard_y)
        if not self.spec.quiet:
            # Blank out the progress graph's line.
            print(" " * 80 + "\n")
        return (X, y)

    def normalize_feature(self, value: str) -> Any:
        """Convert a raw feature string into an integer.

        ``"False"``/``"None"`` map to 0 and ``"True"`` to 255.  Strings
        containing a ``.`` are parsed as floats, scaled by 100 and rounded;
        anything else goes through ``int()``.  Non-numeric values therefore
        raise ValueError.
        """
        if value in ("False", "None"):
            ret = 0
        elif value == "True":
            ret = 255
        elif isinstance(value, str) and "." in value:
            ret = round(float(value) * 100.0)
        else:
            ret = int(value)
        return ret

    def test_train_split(self, X, y) -> List:
        """Randomly split X/y, returning ``[X_train, X_test, y_train, y_test]``.

        The split's random_state is drawn from the module-level ``random``
        generator, which train() reseeds on every call.
        """
        logger.debug("Performing test/train split")
        return train_test_split(
            X,
            y,
            random_state=random.randrange(0, 1000),
        )

    def scale_data(self,
                   X_train: np.ndarray,
                   X_test: np.ndarray) -> Tuple[Any, np.ndarray, np.ndarray]:
        """Fit a MinMaxScaler on the training split and apply it to both splits.

        Returns:
            ``(scaler, X_train_scaled, X_test_scaled)``.
        """
        logger.debug("Scaling data")
        scaler = MinMaxScaler()
        # Fit on training data only; the test split reuses that transform.
        scaler.fit(X_train)
        return (scaler, scaler.transform(X_train), scaler.transform(X_test))

    # Note: children should implement.  Consider using @parallelize.
    @abstractmethod
    def train_model(self,
                    parameters,
                    X_train_scaled: np.ndarray,
                    y_train: np.ndarray) -> Any:
        """Train and return one model (or a SmartFuture) for *parameters*.

        The returned object must support ``get_id()``, which train() calls on
        it, and once resolved must provide ``score(X, y)``.
        """
        pass

    def evaluate_model(
            self,
            model: Any,
            X_train_scaled: np.ndarray,
            y_train: np.ndarray,
            X_test_scaled: np.ndarray,
            y_test: np.ndarray) -> Tuple[np.float64, np.float64]:
        """Return ``(training_score, test_score)`` as percentages.

        Both come from ``model.score(...)`` multiplied by 100 (this presumes
        score() returns a fraction such as sklearn accuracy -- confirm for
        non-sklearn models).
        """
        logger.debug("Evaluating the model")
        training_score = model.score(X_train_scaled, y_train) * 100.0
        test_score = model.score(X_test_scaled, y_test) * 100.0
        logger.info(
            f"Model evaluation results: test_score={test_score:.5f}, "
            f"train_score={training_score:.5f}"
        )
        return (training_score, test_score)

    def maybe_persist_scaler_and_model(
            self,
            training_score: np.float64,
            test_score: np.float64,
            params: str,
            num_examples: int,
            scaler: Any,
            model: Any) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """Pickle the scaler and model and write a model-info file, if allowed.

        Nothing is written on a dry run.  Otherwise the files are written
        when the test score meets ``spec.persist_predicate`` (inclusive), or,
        failing that and when not quiet, when the user answers "y" at the
        prompt.

        Returns:
            ``(scaler_filename, model_filename, model_info_filename)``, or
            ``(None, None, None)`` when nothing was written.
        """
        if self.spec.dry_run:
            return (None, None, None)

        # Inclusive comparison, matching the --ml_trainer_predicate help
        # text; the prompt is only shown if the predicate is unset or unmet.
        should_write = (
            self.spec.persist_predicate is not None
            and test_score >= self.spec.persist_predicate
        ) or (
            not self.spec.quiet
            and input_utils.yn_response("Write the model? [y,n]: ") == "y"
        )
        if not should_write:
            return (None, None, None)

        scaler_filename = f"{self.spec.basename}_scaler.sav"
        with open(scaler_filename, "wb") as f:
            pickle.dump(scaler, f)
        msg = f"Wrote {scaler_filename}"
        print(msg)
        logger.info(msg)

        model_filename = f"{self.spec.basename}_model.sav"
        with open(model_filename, "wb") as f:
            pickle.dump(model, f)
        msg = f"Wrote {model_filename}"
        print(msg)
        logger.info(msg)

        model_info_filename = f"{self.spec.basename}_model_info.txt"
        now: datetime.datetime = datetime_utils.now_pst()
        info = f"""Timestamp: {datetime_utils.datetime_to_string(now)}
Model params: {params}
Training examples: {num_examples}
Training set score: {training_score:.2f}%
Testing set score: {test_score:.2f}%"""
        with open(model_info_filename, "w") as f:
            f.write(info)
        msg = f"Wrote {model_info_filename}:"
        print(msg)
        logger.info(msg)
        print(string_utils.indent(info, 2))
        logger.info(info)
        return (scaler_filename, model_filename, model_info_filename)