python_utils.git: ml/model_trainer.py
#!/usr/bin/env python3
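"""A framework for training sklearn-style ML models in parallel.

Feature files matched by a glob are parsed into examples, split into
training and test sets, scaled, and used to train one model per set of
training parameters.  The best-scoring model (and its scaler) can then
be pickled to disk along with a small human-readable info file.
"""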

from __future__ import annotations

import datetime
import glob
import logging
import os
import pickle
import random
import sys
import warnings
from abc import ABC, abstractmethod
from types import SimpleNamespace
from typing import Any, List, NamedTuple, Optional, Set, Tuple

import numpy as np
from sklearn.model_selection import train_test_split  # type: ignore
from sklearn.preprocessing import MinMaxScaler  # type: ignore

import argparse_utils
import config
import executors
import parallelize as par
from ansi import bold, reset
from decorator_utils import timed

logger = logging.getLogger(__file__)

parser = config.add_commandline_args(
    f"ML Model Trainer ({__file__})",
    "Arguments related to training an ML model",
)
parser.add_argument(
    "--ml_trainer_quiet",
    action="store_true",
    help="Don't prompt the user for anything.",
)
parser.add_argument(
    "--ml_trainer_delete",
    action="store_true",
    help="Delete invalid/incomplete features files in addition to warning.",
)
group = parser.add_mutually_exclusive_group()
group.add_argument(
    "--ml_trainer_dry_run",
    action="store_true",
    help="Do not write a new model, just report efficacy.",
)
group.add_argument(
    "--ml_trainer_persist_threshold",
    type=argparse_utils.valid_percentage,
    metavar="0..100",
    help="Persist the model if the test set score is >= this threshold.",
)


class InputSpec(SimpleNamespace):
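    """Describes the inputs to the training process: where the feature
    files live, how to parse them, which training parameters to try, and
    how (or whether) to persist the result."""
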
    file_glob: str
    feature_count: int
    features_to_skip: Set[str]
    key_value_delimiter: str
    training_parameters: List
    label: str
    basename: str
    dry_run: Optional[bool]
    quiet: Optional[bool]
    persist_percentage_threshold: Optional[float]
    delete_bad_inputs: Optional[bool]

    @staticmethod
    def populate_from_config() -> InputSpec:
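        """Construct an InputSpec from the global commandline config."""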
        return InputSpec(
            dry_run=config.config["ml_trainer_dry_run"],
            quiet=config.config["ml_trainer_quiet"],
            persist_percentage_threshold=config.config["ml_trainer_persist_threshold"],
            delete_bad_inputs=config.config["ml_trainer_delete"],
        )


class OutputSpec(NamedTuple):
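    """Describes the results of training: the filenames written (None if
    nothing was persisted) and the best model's training/test scores."""
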
    model_filename: Optional[str]
    model_info_filename: Optional[str]
    scaler_filename: Optional[str]
    training_score: np.float64
    test_score: np.float64


class TrainingBlueprint(ABC):
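    """An abstract base class implementing the full training pipeline:
    reading feature files, splitting and scaling data, evaluating models,
    and optionally persisting the best one.  Subclasses provide the model
    construction logic by implementing train_model()."""
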
    def __init__(self):
        self.y_train = None
        self.y_test = None
        self.X_test_scaled = None
        self.X_train_scaled = None
        self.file_done_count = 0
        self.total_file_count = 0
        self.spec = None

    def train(self, spec: InputSpec) -> OutputSpec:
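        """Run the end-to-end training pipeline described by spec and
        return an OutputSpec summarizing the best model found."""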
        import smart_future

        random.seed()
        self.spec = spec

        X_, y_ = self.read_input_files()
        num_examples = len(y_)

        # Every example's features
        X = np.array(X_)

        # Every example's label
        y = np.array(y_)

        print("Doing random test/train split...")
        X_train, X_test, self.y_train, self.y_test = self.test_train_split(
            X,
            y,
        )

        print("Scaling training data...")
        scaler, self.X_train_scaled, self.X_test_scaled = self.scale_data(
            X_train,
            X_test,
        )

        print("Training model(s)...")
        models = []
        modelid_to_params = {}
        for params in self.spec.training_parameters:
            model = self.train_model(params, self.X_train_scaled, self.y_train)
            models.append(model)
            modelid_to_params[model.get_id()] = str(params)

        best_model = None
        best_score: Optional[np.float64] = None
        best_test_score: Optional[np.float64] = None
        best_training_score: Optional[np.float64] = None
        best_params = None
        for model in smart_future.wait_any(models):
            params = modelid_to_params[model.get_id()]
            if isinstance(model, smart_future.SmartFuture):
                model = model._resolve()
            if model is not None:
                training_score, test_score = self.evaluate_model(
                    model,
                    self.X_train_scaled,
                    self.y_train,
                    self.X_test_scaled,
                    self.y_test,
                )
                score = (training_score + test_score * 20) / 21
                if not self.spec.quiet:
                    print(
                        f"{bold()}{params}{reset()}: "
                        f"Training set score={training_score:.2f}%, "
                        f"test set score={test_score:.2f}%",
                        file=sys.stderr,
                    )
                if best_score is None or score > best_score:
                    best_score = score
                    best_test_score = test_score
                    best_training_score = training_score
                    best_model = model
                    best_params = params
                    if not self.spec.quiet:
                        print(f"New best score {best_score:.2f}% with params {params}")

        if not self.spec.quiet:
            executors.DefaultExecutors().shutdown()
            msg = f"Done training; best test set score was: {best_test_score:.1f}%"
            print(msg)
            logger.info(msg)

        assert best_training_score is not None
        assert best_test_score is not None
        assert best_params is not None
        (
            scaler_filename,
            model_filename,
            model_info_filename,
        ) = self.maybe_persist_scaler_and_model(
            best_training_score,
            best_test_score,
            best_params,
            num_examples,
            scaler,
            best_model,
        )
        return OutputSpec(
            model_filename=model_filename,
            model_info_filename=model_info_filename,
            scaler_filename=scaler_filename,
            training_score=best_training_score,
            test_score=best_test_score,
        )

    @par.parallelize(method=par.Method.THREAD)
    def read_files_from_list(self, files: List[str], n: int) -> Tuple[List, List]:
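        """Read one shard of feature files (in a worker thread) and return
        the (features, labels) parsed from them.  Malformed files are
        skipped and, optionally, deleted."""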
        # All features
        X = []

        # The label
        y = []

        for filename in files:
            wrote_label = False
            with open(filename, "r") as f:
                lines = f.readlines()

            # This example's features
            x = []
            for line in lines:

                # We expect lines in features files to be of the form:
                #
                # key: value
                line = line.strip()
                try:
                    (key, value) = line.split(self.spec.key_value_delimiter)
                except Exception:
                    logger.debug(f"WARNING: bad line in file {filename} '{line}', skipped")
                    continue

                key = key.strip()
                value = value.strip()
                if self.spec.features_to_skip is not None and key in self.spec.features_to_skip:
                    logger.debug(f"Skipping feature {key}")
                    continue

                value = self.normalize_feature(value)

                if key == self.spec.label:
                    y.append(value)
                    wrote_label = True
                else:
                    x.append(value)

            # Make sure we saw a label and the requisite number of features.
            if len(x) == self.spec.feature_count and wrote_label:
                X.append(x)
                self.file_done_count += 1
            else:
                if wrote_label:
                    y.pop()

                if self.spec.delete_bad_inputs:
                    msg = f"WARNING: {filename}: missing features or label; expected {self.spec.feature_count} but saw {len(x)}.  DELETING."
                    logger.warning(msg)
                    warnings.warn(msg)
                    os.remove(filename)
                else:
                    msg = f"WARNING: {filename}: missing features or label; expected {self.spec.feature_count} but saw {len(x)}.  Skipping."
                    logger.warning(msg)
                    warnings.warn(msg)
        return (X, y)

    def make_progress_graph(self) -> None:
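        """Render a progress graph of files processed so far, unless quiet."""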
        if not self.spec.quiet:
            from text_utils import progress_graph

            progress_graph(self.file_done_count, self.total_file_count)

    @timed
    def read_input_files(self):
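        """Read every file matching the input glob in parallel shards and
        return the combined (X, y) examples."""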
        import list_utils
        import smart_future

        # All features
        X = []

        # The label
        y = []

        results = []
        all_files = glob.glob(self.spec.file_glob)
        self.total_file_count = len(all_files)
        for n, files in enumerate(list_utils.shard(all_files, 500)):
            file_list = list(files)
            results.append(self.read_files_from_list(file_list, n))

        for result in smart_future.wait_any(results, callback=self.make_progress_graph):
            result = result._resolve()
            for z in result[0]:
                X.append(z)
            for z in result[1]:
                y.append(z)
        if not self.spec.quiet:
            print(" " * 80 + "\n")
        return (X, y)

    def normalize_feature(self, value: str) -> Any:
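        """Convert a raw feature string into a number: False/None become 0,
        True becomes 255, values containing a '.' are treated as floats and
        scaled by 100, and everything else is parsed as an int."""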
        if value in ("False", "None"):
            ret = 0
        elif value == "True":
            ret = 255
        elif isinstance(value, str) and "." in value:
            ret = round(float(value) * 100.0)
        else:
            ret = int(value)
        return ret

    def test_train_split(self, X, y) -> List:
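        """Randomly split the examples into training and test sets."""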
        logger.debug("Performing test/train split")
        return train_test_split(
            X,
            y,
            random_state=random.randrange(0, 1000),
        )

    def scale_data(
        self, X_train: np.ndarray, X_test: np.ndarray
    ) -> Tuple[Any, np.ndarray, np.ndarray]:
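        """Fit a MinMaxScaler to the training data and return the scaler
        plus the scaled training and test sets."""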
        logger.debug("Scaling data")
        scaler = MinMaxScaler()
        scaler.fit(X_train)
        return (scaler, scaler.transform(X_train), scaler.transform(X_test))

    # Note: children should implement.  Consider using @parallelize.
    @abstractmethod
    def train_model(self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray) -> Any:
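        """Train and return one model from parameters and the scaled
        training data.  Children must implement this; consider decorating
        the implementation with @par.parallelize so that several parameter
        sets can train concurrently."""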
        pass

    def evaluate_model(
        self,
        model: Any,
        X_train_scaled: np.ndarray,
        y_train: np.ndarray,
        X_test_scaled: np.ndarray,
        y_test: np.ndarray,
    ) -> Tuple[np.float64, np.float64]:
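        """Score the model on the training and test sets and return both
        scores as percentages."""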
        logger.debug("Evaluating the model")
        training_score = model.score(X_train_scaled, y_train) * 100.0
        test_score = model.score(X_test_scaled, y_test) * 100.0
        logger.info(
            f"Model evaluation results: test_score={test_score:.5f}, "
            f"train_score={training_score:.5f}"
        )
        return (training_score, test_score)

    def maybe_persist_scaler_and_model(
        self,
        training_score: np.float64,
        test_score: np.float64,
        params: str,
        num_examples: int,
        scaler: Any,
        model: Any,
    ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
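        """Unless this is a dry run, report the model's scores and, if the
        persist threshold is met (or the user says yes), pickle the scaler
        and model and write a human-readable info file.  Returns the
        filenames written, or (None, None, None) if nothing was persisted."""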
        if not self.spec.dry_run:
            import datetime_utils
            import input_utils
            import string_utils

            now: datetime.datetime = datetime_utils.now_pacific()
            info = f"""Timestamp: {datetime_utils.datetime_to_string(now)}
Model params: {params}
Training examples: {num_examples}
Training set score: {training_score:.2f}%
Testing set score: {test_score:.2f}%"""
            print(f"\n{info}\n")
            if (
                self.spec.persist_percentage_threshold is not None
                and test_score > self.spec.persist_percentage_threshold
            ) or (
                not self.spec.quiet and input_utils.yn_response("Write the model? [y,n]: ") == "y"
            ):
                scaler_filename = f"{self.spec.basename}_scaler.sav"
                with open(scaler_filename, "wb") as fb:
                    pickle.dump(scaler, fb)
                msg = f"Wrote {scaler_filename}"
                print(msg)
                logger.info(msg)
                model_filename = f"{self.spec.basename}_model.sav"
                with open(model_filename, "wb") as fb:
                    pickle.dump(model, fb)
                msg = f"Wrote {model_filename}"
                print(msg)
                logger.info(msg)
                model_info_filename = f"{self.spec.basename}_model_info.txt"
                with open(model_info_filename, "w") as f:
                    f.write(info)
                msg = f"Wrote {model_info_filename}:"
                print(msg)
                logger.info(msg)
                print(string_utils.indent(info, 2))
                logger.info(info)
                return (scaler_filename, model_filename, model_info_filename)
        return (None, None, None)
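

# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of this module): one way a child class
# might fill in the blueprint, following the base class's hint to decorate
# train_model with @parallelize.  The KNeighborsRegressor, the dict-shaped
# parameter sets, and the InputSpec values in the comment below are
# assumptions chosen for the example, not something this file prescribes.
# ---------------------------------------------------------------------------
class ExampleKNNTrainer(TrainingBlueprint):
    @par.parallelize(method=par.Method.THREAD)
    def train_model(self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray) -> Any:
        from sklearn.neighbors import KNeighborsRegressor  # type: ignore

        # Each entry in spec.training_parameters is assumed to be a dict
        # like {"n_neighbors": 5}.
        model = KNeighborsRegressor(n_neighbors=parameters["n_neighbors"])
        model.fit(X_train_scaled, y_train)
        return model


# Hypothetical invocation:
#
#   spec = InputSpec.populate_from_config()
#   spec.file_glob = "features/*.txt"
#   spec.feature_count = 12
#   spec.features_to_skip = set()
#   spec.key_value_delimiter = ":"
#   spec.training_parameters = [{"n_neighbors": k} for k in (3, 5, 7)]
#   spec.label = "label"
#   spec.basename = "knn_example"
#   output = ExampleKNNTrainer().train(spec)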