[python_utils.git] / ml / model_trainer.py
#!/usr/bin/env python3

3 """This is a blueprint for training sklearn ML models."""

from __future__ import annotations
import datetime
import glob
import logging
import os
import pickle
import random
import sys
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, List, Optional, Set, Tuple

import numpy as np
from sklearn.model_selection import train_test_split  # type: ignore
from sklearn.preprocessing import MinMaxScaler  # type: ignore

import argparse_utils
import config
import executors
import parallelize as par
from ansi import bold, reset
from decorator_utils import timed

logger = logging.getLogger(__name__)

parser = config.add_commandline_args(
    f"ML Model Trainer ({__file__})",
    "Arguments related to training an ML model",
)
parser.add_argument(
    "--ml_trainer_quiet",
    action="store_true",
    help="Don't prompt the user for anything.",
)
parser.add_argument(
    "--ml_trainer_delete",
    action="store_true",
    help="Delete invalid/incomplete features files in addition to warning.",
)
group = parser.add_mutually_exclusive_group()
group.add_argument(
    "--ml_trainer_dry_run",
    action="store_true",
    help="Do not write a new model, just report efficacy.",
)
group.add_argument(
    "--ml_trainer_persist_threshold",
    type=argparse_utils.valid_percentage,
    metavar="0..100",
    help="Persist the model if the test set score is >= this threshold.",
)


class InputSpec(SimpleNamespace):
61     """A collection of info needed to train the model provided by the
62     caller."""

    file_glob: str
    feature_count: int
    features_to_skip: Set[str]
    key_value_delimiter: str
    training_parameters: List
    label: str
    basename: str
    dry_run: Optional[bool]
    quiet: Optional[bool]
    persist_percentage_threshold: Optional[float]
    delete_bad_inputs: Optional[bool]

    @staticmethod
    def populate_from_config() -> InputSpec:
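        """Construct a partial InputSpec from commandline arguments.

        Note: this only fills in the fields driven by the --ml_trainer_*
        flags above; callers must still supply file_glob, feature_count,
        features_to_skip, key_value_delimiter, training_parameters, label
        and basename themselves.
        """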
        return InputSpec(
            dry_run=config.config["ml_trainer_dry_run"],
            quiet=config.config["ml_trainer_quiet"],
            persist_percentage_threshold=config.config["ml_trainer_persist_threshold"],
            delete_bad_inputs=config.config["ml_trainer_delete"],
        )


@dataclass
class OutputSpec:
    """Info about the results of training returned to the caller."""

    model_filename: Optional[str] = None
    model_info_filename: Optional[str] = None
    scaler_filename: Optional[str] = None
    training_score: np.float64 = np.float64(0.0)
    test_score: np.float64 = np.float64(0.0)


class TrainingBlueprint(ABC):
    """The blueprint for doing the actual training."""

    def __init__(self):
        self.y_train = None
        self.y_test = None
        self.X_test_scaled = None
        self.X_train_scaled = None
        self.file_done_count = 0
        self.total_file_count = 0
        self.spec = None

    def train(self, spec: InputSpec) -> OutputSpec:
        import smart_future

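        # Re-seed the global RNG so each run produces a different
        # test/train split (test_train_split below draws its random_state
        # from the random module).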
        random.seed()
        self.spec = spec

        X_, y_ = self.read_input_files()
        num_examples = len(y_)

        # Every example's features
        X = np.array(X_)

        # Every example's label
        y = np.array(y_)

        print("Doing random test/train split...")
        X_train, X_test, self.y_train, self.y_test = TrainingBlueprint.test_train_split(
            X,
            y,
        )

        print("Scaling training data...")
        scaler, self.X_train_scaled, self.X_test_scaled = TrainingBlueprint.scale_data(
            X_train,
            X_test,
        )

        print("Training model(s)...")
        models = []
        modelid_to_params = {}
        for params in self.spec.training_parameters:
            model = self.train_model(params, self.X_train_scaled, self.y_train)
            models.append(model)
            modelid_to_params[model.get_id()] = str(params)

        best_model = None
        best_score: Optional[np.float64] = None
        best_test_score: Optional[np.float64] = None
        best_training_score: Optional[np.float64] = None
        best_params = None
        for model in smart_future.wait_any(models):
            params = modelid_to_params[model.get_id()]
            if isinstance(model, smart_future.SmartFuture):
                model = model._resolve()
            if model is not None:
                training_score, test_score = TrainingBlueprint.evaluate_model(
                    model,
                    self.X_train_scaled,
                    self.y_train,
                    self.X_test_scaled,
                    self.y_test,
                )
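                # Blend the two scores into a single ranking metric,
                # weighting the test set score 20x more heavily than the
                # training set score.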
                score = (training_score + test_score * 20) / 21
                if not self.spec.quiet:
                    print(
                        f"{bold()}{params}{reset()}: "
                        f"Training set score={training_score:.2f}%, "
                        f"test set score={test_score:.2f}%",
                        file=sys.stderr,
                    )
                if best_score is None or score > best_score:
                    best_score = score
                    best_test_score = test_score
                    best_training_score = training_score
                    best_model = model
                    best_params = params
                    if not self.spec.quiet:
                        print(f"New best score {best_score:.2f}% with params {params}")

        if not self.spec.quiet:
            executors.DefaultExecutors().shutdown()
            msg = f"Done training; best test set score was: {best_test_score:.1f}%"
            print(msg)
            logger.info(msg)

        assert best_training_score is not None
        assert best_test_score is not None
        assert best_params is not None
        assert best_model is not None
        (
            scaler_filename,
            model_filename,
            model_info_filename,
        ) = self.maybe_persist_scaler_and_model(
            best_training_score,
            best_test_score,
            best_params,
            num_examples,
            scaler,
            best_model,
        )
        return OutputSpec(
            model_filename=model_filename,
            model_info_filename=model_info_filename,
            scaler_filename=scaler_filename,
            training_score=best_training_score,
            test_score=best_test_score,
        )

    @par.parallelize(method=par.Method.THREAD)
    def read_files_from_list(self, files: List[str]) -> Tuple[List, List]:
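        """Parse one shard of features files and return (X, y) lists.

        Note: because of the @par.parallelize decorator, callers get this
        result back asynchronously and must resolve it before use (see
        read_input_files, which uses smart_future.wait_any to do so).
        """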
        # All features
        X = []

        # The label
        y = []

        for filename in files:
            wrote_label = False
            with open(filename, "r") as f:
                lines = f.readlines()

            # This example's features
            x = []
            for line in lines:

                # We expect lines in features files to be of the form:
                #
                # key: value
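                #
                # where the separator is spec.key_value_delimiter.  The line
                # whose key equals spec.label carries the example's label;
                # every other line carries one feature value.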
                line = line.strip()
                try:
                    (key, value) = line.split(self.spec.key_value_delimiter)
                except Exception:
                    logger.debug("Skipping bad line in file %s: '%s'", filename, line)
                    continue

                key = key.strip()
                value = value.strip()
                if self.spec.features_to_skip is not None and key in self.spec.features_to_skip:
                    logger.debug("Skipping feature %s", key)
                    continue

                value = TrainingBlueprint.normalize_feature(value)

                if key == self.spec.label:
                    y.append(value)
                    wrote_label = True
                else:
                    x.append(value)

            # Make sure we saw a label and the requisite number of features.
            if len(x) == self.spec.feature_count and wrote_label:
                X.append(x)
                self.file_done_count += 1
            else:
                if wrote_label:
                    y.pop()

                action = "DELETING" if self.spec.delete_bad_inputs else "Skipping"
                msg = (
                    f"WARNING: {filename}: missing features or label; "
                    f"expected {self.spec.feature_count} but saw {len(x)}.  {action}."
                )
                logger.warning(msg)
                warnings.warn(msg)
                if self.spec.delete_bad_inputs:
                    os.remove(filename)
        return (X, y)

    def make_progress_graph(self) -> None:
        if not self.spec.quiet:
            from text_utils import progress_graph

            progress_graph(self.file_done_count, self.total_file_count)

    @timed
    def read_input_files(self):
        import list_utils
        import smart_future

        # All features
        X = []

        # The label
        y = []

        results = []
        all_files = glob.glob(self.spec.file_glob)
        self.total_file_count = len(all_files)
        for files in list_utils.shard(all_files, 500):
            file_list = list(files)
            results.append(self.read_files_from_list(file_list))

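        # Collect results as the parallel shard readers finish; each result
        # is a future that must be resolved into its underlying (X, y) lists.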
        for result in smart_future.wait_any(results, callback=self.make_progress_graph):
            result = result._resolve()
            for z in result[0]:
                X.append(z)
            for z in result[1]:
                y.append(z)
        if not self.spec.quiet:
            print(" " * 80 + "\n")
        return (X, y)

    @staticmethod
    def normalize_feature(value: str) -> Any:
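        """Convert a raw feature value (a string) into an integer.

        "False"/"None" become 0, "True" becomes 255, values containing a
        "." are treated as floats and scaled by 100 (then rounded), and
        everything else is parsed with int().
        """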
        if value in ("False", "None"):
            ret = 0
        elif value == "True":
            ret = 255
        elif isinstance(value, str) and "." in value:
            ret = round(float(value) * 100.0)
        else:
            ret = int(value)
        return ret

    @staticmethod
    def test_train_split(X, y) -> List:
        logger.debug("Performing test/train split")
        return train_test_split(
            X,
            y,
            random_state=random.randrange(0, 1000),
        )

    @staticmethod
    def scale_data(X_train: np.ndarray, X_test: np.ndarray) -> Tuple[Any, np.ndarray, np.ndarray]:
        logger.debug("Scaling data")
        scaler = MinMaxScaler()
        scaler.fit(X_train)
        return (scaler, scaler.transform(X_train), scaler.transform(X_test))

    # Note: children should implement.  Consider using @parallelize.
    @abstractmethod
    def train_model(self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray) -> Any:
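        """Train one model with the given parameters on the scaled training
        data and return it.

        Whatever is returned must expose get_id() (train uses it to map
        models back to their parameters), and the trained model itself must
        expose score() (see evaluate_model).  If this method is decorated
        with @parallelize, train resolves the returned future before
        evaluating it.
        """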
        pass

    @staticmethod
    def evaluate_model(
        model: Any,
        X_train_scaled: np.ndarray,
        y_train: np.ndarray,
        X_test_scaled: np.ndarray,
        y_test: np.ndarray,
    ) -> Tuple[np.float64, np.float64]:
        logger.debug("Evaluating the model")
        training_score = model.score(X_train_scaled, y_train) * 100.0
        test_score = model.score(X_test_scaled, y_test) * 100.0
        logger.info(
            "Model evaluation results: test_score=%.5f, train_score=%.5f",
            test_score,
            training_score,
        )
        return (training_score, test_score)

    def maybe_persist_scaler_and_model(
        self,
        training_score: np.float64,
        test_score: np.float64,
        params: str,
        num_examples: int,
        scaler: Any,
        model: Any,
    ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        if not self.spec.dry_run:
            import datetime_utils
            import input_utils
            import string_utils

            now: datetime.datetime = datetime_utils.now_pacific()
            info = f"""Timestamp: {datetime_utils.datetime_to_string(now)}
Model params: {params}
Training examples: {num_examples}
Training set score: {training_score:.2f}%
Testing set score: {test_score:.2f}%"""
            print(f"\n{info}\n")
            if (
                self.spec.persist_percentage_threshold is not None
                and test_score > self.spec.persist_percentage_threshold
            ) or (
                not self.spec.quiet and input_utils.yn_response("Write the model? [y,n]: ") == "y"
            ):
                scaler_filename = f"{self.spec.basename}_scaler.sav"
                with open(scaler_filename, "wb") as fb:
                    pickle.dump(scaler, fb)
                msg = f"Wrote {scaler_filename}"
                print(msg)
                logger.info(msg)
                model_filename = f"{self.spec.basename}_model.sav"
                with open(model_filename, "wb") as fb:
                    pickle.dump(model, fb)
                msg = f"Wrote {model_filename}"
                print(msg)
                logger.info(msg)
                model_info_filename = f"{self.spec.basename}_model_info.txt"
                with open(model_info_filename, "w") as f:
                    f.write(info)
                msg = f"Wrote {model_info_filename}:"
                print(msg)
                logger.info(msg)
                print(string_utils.indent(info, 2))
                logger.info(info)
                return (scaler_filename, model_filename, model_info_filename)
        return (None, None, None)
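

# A minimal usage sketch (hypothetical names: KNeighborsRegressor is just an
# illustrative estimator and the file glob, label and parameter values are
# made up).  A concrete child implements train_model and then drives
# training roughly like this:
#
#   from sklearn.neighbors import KNeighborsRegressor  # type: ignore
#
#   class ExampleTrainer(TrainingBlueprint):
#       @par.parallelize(method=par.Method.THREAD)
#       def train_model(self, parameters, X_train_scaled, y_train):
#           model = KNeighborsRegressor(n_neighbors=parameters)
#           model.fit(X_train_scaled, y_train)
#           return model
#
#   spec = InputSpec.populate_from_config()
#   spec.file_glob = "/path/to/features_*.txt"
#   spec.feature_count = 20
#   spec.features_to_skip = set()
#   spec.key_value_delimiter = ":"
#   spec.training_parameters = [3, 5, 9]
#   spec.label = "label"
#   spec.basename = "example"
#   out = ExampleTrainer().train(spec)
#   print(out.test_score)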