[python_utils.git] / ml / model_trainer.py
#!/usr/bin/env python3

# © Copyright 2021-2022, Scott Gasch

"""This is a blueprint for training sklearn ML models."""

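# Rough usage sketch (the subclass and variable names below are hypothetical,
# not part of this module):
#
#   class MyTrainer(TrainingBlueprint):
#       def train_model(self, parameters, X_train_scaled, y_train):
#           ...  # fit and return an sklearn model; consider using @parallelize
#
#   spec = InputSpec.populate_from_config()  # plus the caller-supplied fields
#   output = MyTrainer().train(spec)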
from __future__ import annotations
import datetime
import glob
import logging
import os
import pickle
import random
import sys
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, List, Optional, Set, Tuple

import numpy as np
from sklearn.model_selection import train_test_split  # type: ignore
from sklearn.preprocessing import MinMaxScaler  # type: ignore

import argparse_utils
import config
import executors
import parallelize as par
from ansi import bold, reset
from decorator_utils import timed

logger = logging.getLogger(__name__)

parser = config.add_commandline_args(
    f"ML Model Trainer ({__file__})",
    "Arguments related to training an ML model",
)
parser.add_argument(
    "--ml_trainer_quiet",
    action="store_true",
    help="Don't prompt the user for anything.",
)
parser.add_argument(
    "--ml_trainer_delete",
    action="store_true",
    help="Delete invalid/incomplete features files in addition to warning.",
)
group = parser.add_mutually_exclusive_group()
group.add_argument(
    "--ml_trainer_dry_run",
    action="store_true",
    help="Do not write a new model, just report efficacy.",
)
group.add_argument(
    "--ml_trainer_persist_threshold",
    type=argparse_utils.valid_percentage,
    metavar="0..100",
    help="Persist the model if the test set score is >= this threshold.",
)


class InputSpec(SimpleNamespace):
    """A collection of info, provided by the caller, needed to train
    the model."""

    file_glob: str
    feature_count: int
    features_to_skip: Set[str]
    key_value_delimiter: str
    training_parameters: List
    label: str
    basename: str
    dry_run: Optional[bool]
    quiet: Optional[bool]
    persist_percentage_threshold: Optional[float]
    delete_bad_inputs: Optional[bool]

    @staticmethod
    def populate_from_config() -> InputSpec:
        return InputSpec(
            dry_run=config.config["ml_trainer_dry_run"],
            quiet=config.config["ml_trainer_quiet"],
            persist_percentage_threshold=config.config["ml_trainer_persist_threshold"],
            delete_bad_inputs=config.config["ml_trainer_delete"],
        )

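# A hypothetical caller-side sketch of filling in an InputSpec; the field
# values below are made up for illustration:
#
#   spec = InputSpec.populate_from_config()
#   spec.file_glob = "/data/features/*.txt"
#   spec.feature_count = 40
#   spec.features_to_skip = {"hash", "timestamp"}
#   spec.key_value_delimiter = ":"
#   spec.training_parameters = [{"n_neighbors": n} for n in (5, 10, 15)]
#   spec.label = "label"
#   spec.basename = "my_model"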

@dataclass
class OutputSpec:
    """Info about the results of training returned to the caller."""

    model_filename: Optional[str] = None
    model_info_filename: Optional[str] = None
    scaler_filename: Optional[str] = None
    training_score: np.float64 = np.float64(0.0)
    test_score: np.float64 = np.float64(0.0)


class TrainingBlueprint(ABC):
    """The blueprint for doing the actual training."""

    def __init__(self):
        self.y_train = None
        self.y_test = None
        self.X_test_scaled = None
        self.X_train_scaled = None
        self.file_done_count = 0
        self.total_file_count = 0
        self.spec = None

    def train(self, spec: InputSpec) -> OutputSpec:
        """Train the model(s) described by spec and return an OutputSpec."""
        import smart_future

        random.seed()
        self.spec = spec

        X_, y_ = self.read_input_files()
        num_examples = len(y_)

        # Every example's features
        X = np.array(X_)

        # Every example's label
        y = np.array(y_)

        print("Doing random test/train split...")
        X_train, X_test, self.y_train, self.y_test = TrainingBlueprint.test_train_split(
            X,
            y,
        )

        print("Scaling training data...")
        scaler, self.X_train_scaled, self.X_test_scaled = TrainingBlueprint.scale_data(
            X_train,
            X_test,
        )

        print("Training model(s)...")
        models = []
        modelid_to_params = {}
        for params in self.spec.training_parameters:
            model = self.train_model(params, self.X_train_scaled, self.y_train)
            models.append(model)
            modelid_to_params[model.get_id()] = str(params)

        best_model = None
        best_score: Optional[np.float64] = None
        best_test_score: Optional[np.float64] = None
        best_training_score: Optional[np.float64] = None
        best_params = None
        for model in smart_future.wait_any(models):
            params = modelid_to_params[model.get_id()]
            if isinstance(model, smart_future.SmartFuture):
                model = model._resolve()
            if model is not None:
                training_score, test_score = TrainingBlueprint.evaluate_model(
                    model,
                    self.X_train_scaled,
                    self.y_train,
                    self.X_test_scaled,
                    self.y_test,
                )
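                # Blend the two scores, weighting the test set score 20x
                # more heavily than the training set score.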
                score = (training_score + test_score * 20) / 21
                if not self.spec.quiet:
                    print(
                        f"{bold()}{params}{reset()}: "
                        f"Training set score={training_score:.2f}%, "
                        f"test set score={test_score:.2f}%",
                        file=sys.stderr,
                    )
                if best_score is None or score > best_score:
                    best_score = score
                    best_test_score = test_score
                    best_training_score = training_score
                    best_model = model
                    best_params = params
                    if not self.spec.quiet:
                        print(f"New best score {best_score:.2f}% with params {params}")

        if not self.spec.quiet:
            executors.DefaultExecutors().shutdown()
            msg = f"Done training; best test set score was: {best_test_score:.1f}%"
            print(msg)
            logger.info(msg)

        assert best_training_score is not None
        assert best_test_score is not None
        assert best_params is not None
        (
            scaler_filename,
            model_filename,
            model_info_filename,
        ) = self.maybe_persist_scaler_and_model(
            best_training_score,
            best_test_score,
            best_params,
            num_examples,
            scaler,
            best_model,
        )
        return OutputSpec(
            model_filename=model_filename,
            model_info_filename=model_info_filename,
            scaler_filename=scaler_filename,
            training_score=best_training_score,
            test_score=best_test_score,
        )

    @par.parallelize(method=par.Method.THREAD)
    def read_files_from_list(self, files: List[str]) -> Tuple[List, List]:
        # All features
        X = []

        # The label
        y = []

        for filename in files:
            wrote_label = False
            with open(filename, "r") as f:
                lines = f.readlines()

            # This example's features
            x = []
            for line in lines:

                # We expect lines in features files to be of the form:
                #
                # key: value
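                #
                # For example (hypothetical features file contents):
                #
                #   sunlight_hours: 4.5
                #   num_pets: 2
                #   label: True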
                line = line.strip()
                try:
                    (key, value) = line.split(self.spec.key_value_delimiter)
                except Exception:
                    logger.debug("WARNING: bad line in file %s '%s', skipped", filename, line)
                    continue

                key = key.strip()
                value = value.strip()
                if self.spec.features_to_skip is not None and key in self.spec.features_to_skip:
                    logger.debug("Skipping feature %s", key)
                    continue

                value = TrainingBlueprint.normalize_feature(value)

                if key == self.spec.label:
                    y.append(value)
                    wrote_label = True
                else:
                    x.append(value)

            # Make sure we saw a label and the requisite number of features.
            if len(x) == self.spec.feature_count and wrote_label:
                X.append(x)
                self.file_done_count += 1
            else:
                if wrote_label:
                    y.pop()

                if self.spec.delete_bad_inputs:
                    msg = f"WARNING: {filename}: missing features or label; expected {self.spec.feature_count} but saw {len(x)}.  DELETING."
                    logger.warning(msg)
                    warnings.warn(msg)
                    os.remove(filename)
                else:
                    msg = f"WARNING: {filename}: missing features or label; expected {self.spec.feature_count} but saw {len(x)}.  Skipping."
                    logger.warning(msg)
                    warnings.warn(msg)
        return (X, y)

    def make_progress_graph(self) -> None:
        if not self.spec.quiet:
            from text_utils import progress_graph

            progress_graph(self.file_done_count, self.total_file_count)

    @timed
    def read_input_files(self):
        import list_utils
        import smart_future

        # All features
        X = []

        # The label
        y = []

        results = []
        all_files = glob.glob(self.spec.file_glob)
        self.total_file_count = len(all_files)
        for files in list_utils.shard(all_files, 500):
            file_list = list(files)
            results.append(self.read_files_from_list(file_list))

        for result in smart_future.wait_any(results, callback=self.make_progress_graph):
            result = result._resolve()
            for z in result[0]:
                X.append(z)
            for z in result[1]:
                y.append(z)
        if not self.spec.quiet:
            print(" " * 80 + "\n")
        return (X, y)

    @staticmethod
    def normalize_feature(value: str) -> Any:
        if value in ("False", "None"):
            ret = 0
        elif value == "True":
            ret = 255
        elif isinstance(value, str) and "." in value:
            ret = round(float(value) * 100.0)
        else:
            ret = int(value)
        return ret

    @staticmethod
    def test_train_split(X, y) -> List:
        logger.debug("Performing test/train split")
        return train_test_split(
            X,
            y,
            random_state=random.randrange(0, 1000),
        )

    @staticmethod
    def scale_data(X_train: np.ndarray, X_test: np.ndarray) -> Tuple[Any, np.ndarray, np.ndarray]:
        logger.debug("Scaling data")
        scaler = MinMaxScaler()
        scaler.fit(X_train)
        return (scaler, scaler.transform(X_train), scaler.transform(X_test))

    # Note: children should implement.  Consider using @parallelize.
    @abstractmethod
    def train_model(self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray) -> Any:
        pass
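    # A minimal sketch of what a child's train_model might look like.  The
    # choice of KNeighborsClassifier and its parameters are illustrative only,
    # not part of this blueprint:
    #
    #   @par.parallelize(method=par.Method.THREAD)
    #   def train_model(self, parameters, X_train_scaled, y_train):
    #       from sklearn.neighbors import KNeighborsClassifier
    #       model = KNeighborsClassifier(n_neighbors=parameters["n_neighbors"])
    #       model.fit(X_train_scaled, y_train)
    #       return model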

    @staticmethod
    def evaluate_model(
        model: Any,
        X_train_scaled: np.ndarray,
        y_train: np.ndarray,
        X_test_scaled: np.ndarray,
        y_test: np.ndarray,
    ) -> Tuple[np.float64, np.float64]:
        logger.debug("Evaluating the model")
        training_score = model.score(X_train_scaled, y_train) * 100.0
        test_score = model.score(X_test_scaled, y_test) * 100.0
        logger.info(
            "Model evaluation results: test_score=%.5f, train_score=%.5f",
            test_score,
            training_score,
        )
        return (training_score, test_score)

    def maybe_persist_scaler_and_model(
        self,
        training_score: np.float64,
        test_score: np.float64,
        params: str,
        num_examples: int,
        scaler: Any,
        model: Any,
    ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        if not self.spec.dry_run:
            import datetime_utils
            import input_utils
            import string_utils

            now: datetime.datetime = datetime_utils.now_pacific()
            info = f"""Timestamp: {datetime_utils.datetime_to_string(now)}
Model params: {params}
Training examples: {num_examples}
Training set score: {training_score:.2f}%
Testing set score: {test_score:.2f}%"""
            print(f"\n{info}\n")
            if (
                self.spec.persist_percentage_threshold is not None
                and test_score > self.spec.persist_percentage_threshold
            ) or (
                not self.spec.quiet and input_utils.yn_response("Write the model? [y,n]: ") == "y"
            ):
                scaler_filename = f"{self.spec.basename}_scaler.sav"
                with open(scaler_filename, "wb") as fb:
                    pickle.dump(scaler, fb)
                msg = f"Wrote {scaler_filename}"
                print(msg)
                logger.info(msg)
                model_filename = f"{self.spec.basename}_model.sav"
                with open(model_filename, "wb") as fb:
                    pickle.dump(model, fb)
                msg = f"Wrote {model_filename}"
                print(msg)
                logger.info(msg)
                model_info_filename = f"{self.spec.basename}_model_info.txt"
                with open(model_info_filename, "w") as f:
                    f.write(info)
                msg = f"Wrote {model_info_filename}:"
                print(msg)
                logger.info(msg)
                print(string_utils.indent(info, 2))
                logger.info(info)
                return (scaler_filename, model_filename, model_info_filename)
        return (None, None, None)