#!/usr/bin/env python3

from __future__ import annotations

import datetime
import glob
import logging
import os
import pickle
import random
import sys
import warnings
from abc import ABC, abstractmethod
from types import SimpleNamespace
from typing import Any, List, NamedTuple, Optional, Set, Tuple

import numpy as np
from sklearn.model_selection import train_test_split  # type: ignore
from sklearn.preprocessing import MinMaxScaler  # type: ignore

import argparse_utils
import config
import executors
import parallelize as par
from ansi import bold, reset
from decorator_utils import timed

logger = logging.getLogger(__file__)

parser = config.add_commandline_args(
    f"ML Model Trainer ({__file__})", "Arguments related to training an ML model"
)
parser.add_argument(
    "--ml_trainer_quiet",
    action="store_true",
    help="Don't prompt the user for anything.",
)
parser.add_argument(
    "--ml_trainer_delete",
    action="store_true",
    help="Delete invalid/incomplete features files in addition to warning.",
)
group = parser.add_mutually_exclusive_group()
group.add_argument(
    "--ml_trainer_dry_run",
    action="store_true",
    help="Do not write a new model, just report efficacy.",
)
group.add_argument(
    "--ml_trainer_persist_threshold",
    type=argparse_utils.valid_percentage,
    metavar="0..100",
    help="Persist the model if the test set score is >= this threshold.",
)
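# Example invocation of a driver script built on this module (the script name
# and threshold value are hypothetical; the flags are the ones defined above):
#
#   ./my_trainer.py --ml_trainer_dry_run
#   ./my_trainer.py --ml_trainer_quiet --ml_trainer_persist_threshold 95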


class InputSpec(SimpleNamespace):
    file_glob: str
    feature_count: int
    features_to_skip: Set[str]
    key_value_delimiter: str
    training_parameters: List
    label: str
    basename: str
    dry_run: Optional[bool]
    quiet: Optional[bool]
    persist_percentage_threshold: Optional[float]
    delete_bad_inputs: Optional[bool]

    @staticmethod
    def populate_from_config() -> InputSpec:
        return InputSpec(
            dry_run=config.config["ml_trainer_dry_run"],
            quiet=config.config["ml_trainer_quiet"],
            persist_percentage_threshold=config.config["ml_trainer_persist_threshold"],
            delete_bad_inputs=config.config["ml_trainer_delete"],
        )

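# Illustrative example of constructing an InputSpec by hand rather than via
# populate_from_config() (the glob, delimiter, and other values below are
# hypothetical):
#
#   spec = InputSpec(
#       file_glob="features/*.txt",
#       feature_count=4,
#       features_to_skip=set(["timestamp"]),
#       key_value_delimiter=":",
#       training_parameters=[1, 3, 5],
#       label="label",
#       basename="knn",
#       dry_run=False,
#       quiet=True,
#       persist_percentage_threshold=90.0,
#       delete_bad_inputs=False,
#   )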

class OutputSpec(NamedTuple):
    model_filename: Optional[str]
    model_info_filename: Optional[str]
    scaler_filename: Optional[str]
    training_score: np.float64
    test_score: np.float64


class TrainingBlueprint(ABC):
    def __init__(self):
        self.y_train = None
        self.y_test = None
        self.X_test_scaled = None
        self.X_train_scaled = None
        self.file_done_count = 0
        self.total_file_count = 0
        self.spec = None

    def train(self, spec: InputSpec) -> OutputSpec:
        import smart_future

        random.seed()
        self.spec = spec

        X_, y_ = self.read_input_files()
        num_examples = len(y_)

        # Every example's features
        X = np.array(X_)

        # Every example's label
        y = np.array(y_)

        print("Doing random test/train split...")
        X_train, X_test, self.y_train, self.y_test = self.test_train_split(
            X,
            y,
        )

        print("Scaling training data...")
        scaler, self.X_train_scaled, self.X_test_scaled = self.scale_data(
            X_train,
            X_test,
        )

        print("Training model(s)...")
        models = []
        modelid_to_params = {}
        for params in self.spec.training_parameters:
            model = self.train_model(params, self.X_train_scaled, self.y_train)
            models.append(model)
            modelid_to_params[model.get_id()] = str(params)

        best_model = None
        best_score: Optional[np.float64] = None
        best_test_score: Optional[np.float64] = None
        best_training_score: Optional[np.float64] = None
        best_params = None
        for model in smart_future.wait_any(models):
            params = modelid_to_params[model.get_id()]
            if isinstance(model, smart_future.SmartFuture):
                model = model._resolve()
            if model is not None:
                training_score, test_score = self.evaluate_model(
                    model,
                    self.X_train_scaled,
                    self.y_train,
                    self.X_test_scaled,
                    self.y_test,
                )
                # Weighted composite score: the test set result counts 20x as
                # much as the training set result.
                score = (training_score + test_score * 20) / 21
                if not self.spec.quiet:
                    print(
                        f"{bold()}{params}{reset()}: "
                        f"Training set score={training_score:.2f}%, "
                        f"test set score={test_score:.2f}%",
                        file=sys.stderr,
                    )
                if best_score is None or score > best_score:
                    best_score = score
                    best_test_score = test_score
                    best_training_score = training_score
                    best_model = model
                    best_params = params
                    if not self.spec.quiet:
                        print(f"New best score {best_score:.2f}% with params {params}")

        if not self.spec.quiet:
            executors.DefaultExecutors().shutdown()
            msg = f"Done training; best test set score was: {best_test_score:.1f}%"
            print(msg)
            logger.info(msg)

        assert best_training_score
        assert best_test_score
        assert best_params
        (
            scaler_filename,
            model_filename,
            model_info_filename,
        ) = self.maybe_persist_scaler_and_model(
            best_training_score,
            best_test_score,
            best_params,
            num_examples,
            scaler,
            best_model,
        )
        return OutputSpec(
            model_filename=model_filename,
            model_info_filename=model_info_filename,
            scaler_filename=scaler_filename,
            training_score=best_training_score,
            test_score=best_test_score,
        )

    @par.parallelize(method=par.Method.THREAD)
    def read_files_from_list(self, files: List[str], n: int) -> Tuple[List, List]:
        # All features
        X = []

        # The label
        y = []

        for filename in files:
            wrote_label = False
            with open(filename, "r") as f:
                lines = f.readlines()

            # This example's features
            x = []
            for line in lines:

                # We expect lines in features files to be of the form:
                #
                # key: value
                line = line.strip()
                try:
                    (key, value) = line.split(self.spec.key_value_delimiter)
                except Exception:
                    logger.debug(
                        f"WARNING: bad line in file {filename} '{line}', skipped"
                    )
                    continue

                key = key.strip()
                value = value.strip()
                if (
                    self.spec.features_to_skip is not None
                    and key in self.spec.features_to_skip
                ):
                    logger.debug(f"Skipping feature {key}")
                    continue

                value = self.normalize_feature(value)

                if key == self.spec.label:
                    y.append(value)
                    wrote_label = True
                else:
                    x.append(value)

            # Make sure we saw a label and the requisite number of features.
            if len(x) == self.spec.feature_count and wrote_label:
                X.append(x)
                self.file_done_count += 1
            else:
                if wrote_label:
                    y.pop()

                if self.spec.delete_bad_inputs:
                    msg = f"WARNING: {filename}: missing features or label; expected {self.spec.feature_count} but saw {len(x)}.  DELETING."
                    logger.warning(msg)
                    warnings.warn(msg)
                    os.remove(filename)
                else:
                    msg = f"WARNING: {filename}: missing features or label; expected {self.spec.feature_count} but saw {len(x)}.  Skipping."
                    logger.warning(msg)
                    warnings.warn(msg)
        return (X, y)
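
    # Example of a features file that read_files_from_list() expects, shown
    # for illustration only (the feature names and values are hypothetical;
    # this assumes key_value_delimiter=":" and label="label"):
    #
    #   temperature: 0.72
    #   humidity: 0.31
    #   is_weekend: True
    #   label: 1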

    def make_progress_graph(self) -> None:
        if not self.spec.quiet:
            from text_utils import progress_graph

            progress_graph(self.file_done_count, self.total_file_count)

    @timed
    def read_input_files(self):
        import list_utils
        import smart_future

        # All features
        X = []

        # The label
        y = []

        results = []
        all_files = glob.glob(self.spec.file_glob)
        self.total_file_count = len(all_files)
        for n, files in enumerate(list_utils.shard(all_files, 500)):
            file_list = list(files)
            results.append(self.read_files_from_list(file_list, n))

        for result in smart_future.wait_any(results, callback=self.make_progress_graph):
            result = result._resolve()
            for z in result[0]:
                X.append(z)
            for z in result[1]:
                y.append(z)
        if not self.spec.quiet:
            print(" " * 80 + "\n")
        return (X, y)

    def normalize_feature(self, value: str) -> Any:
        # Map raw string feature values onto ints: "False"/"None" -> 0,
        # "True" -> 255, values containing "." -> round(float * 100),
        # everything else is parsed as an int.
        if value in ("False", "None"):
            ret = 0
        elif value == "True":
            ret = 255
        elif isinstance(value, str) and "." in value:
            ret = round(float(value) * 100.0)
        else:
            ret = int(value)
        return ret

    def test_train_split(self, X, y) -> List:
        logger.debug("Performing test/train split")
        return train_test_split(
            X,
            y,
            random_state=random.randrange(0, 1000),
        )

    def scale_data(
        self, X_train: np.ndarray, X_test: np.ndarray
    ) -> Tuple[Any, np.ndarray, np.ndarray]:
        logger.debug("Scaling data")
        scaler = MinMaxScaler()
        scaler.fit(X_train)
        return (scaler, scaler.transform(X_train), scaler.transform(X_test))

    # Note: children should implement.  Consider using @parallelize.
    @abstractmethod
    def train_model(
        self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray
    ) -> Any:
        pass

    def evaluate_model(
        self,
        model: Any,
        X_train_scaled: np.ndarray,
        y_train: np.ndarray,
        X_test_scaled: np.ndarray,
        y_test: np.ndarray,
    ) -> Tuple[np.float64, np.float64]:
        logger.debug("Evaluating the model")
        training_score = model.score(X_train_scaled, y_train) * 100.0
        test_score = model.score(X_test_scaled, y_test) * 100.0
        logger.info(
            f"Model evaluation results: test_score={test_score:.5f}, "
            f"train_score={training_score:.5f}"
        )
        return (training_score, test_score)

    def maybe_persist_scaler_and_model(
        self,
        training_score: np.float64,
        test_score: np.float64,
        params: str,
        num_examples: int,
        scaler: Any,
        model: Any,
    ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        if not self.spec.dry_run:
            import datetime_utils
            import input_utils
            import string_utils

            now: datetime.datetime = datetime_utils.now_pacific()
            info = f"""Timestamp: {datetime_utils.datetime_to_string(now)}
Model params: {params}
Training examples: {num_examples}
Training set score: {training_score:.2f}%
Testing set score: {test_score:.2f}%"""
            print(f"\n{info}\n")
            if (
                self.spec.persist_percentage_threshold is not None
                and test_score > self.spec.persist_percentage_threshold
            ) or (
                not self.spec.quiet
                and input_utils.yn_response("Write the model? [y,n]: ") == "y"
            ):
                scaler_filename = f"{self.spec.basename}_scaler.sav"
                with open(scaler_filename, "wb") as fb:
                    pickle.dump(scaler, fb)
                msg = f"Wrote {scaler_filename}"
                print(msg)
                logger.info(msg)
                model_filename = f"{self.spec.basename}_model.sav"
                with open(model_filename, "wb") as fb:
                    pickle.dump(model, fb)
                msg = f"Wrote {model_filename}"
                print(msg)
                logger.info(msg)
                model_info_filename = f"{self.spec.basename}_model_info.txt"
                with open(model_info_filename, "w") as f:
                    f.write(info)
                msg = f"Wrote {model_info_filename}:"
                print(msg)
                logger.info(msg)
                print(string_utils.indent(info, 2))
                logger.info(info)
                return (scaler_filename, model_filename, model_info_filename)
        return (None, None, None)
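

# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of this module): a minimal concrete
# trainer built on TrainingBlueprint.  "ExampleKNNTrainer" is hypothetical;
# the sketch assumes scikit-learn's KNeighborsClassifier and relies on
# @par.parallelize returning futures that expose get_id(), which train()
# uses to map each model back to its training parameters.
# ---------------------------------------------------------------------------
#
#   from sklearn.neighbors import KNeighborsClassifier
#
#   class ExampleKNNTrainer(TrainingBlueprint):
#       @par.parallelize(method=par.Method.THREAD)
#       def train_model(self, parameters, X_train_scaled, y_train):
#           model = KNeighborsClassifier(n_neighbors=parameters)
#           model.fit(X_train_scaled, y_train)
#           return model
#
#   # Given a fully populated InputSpec (see the example after InputSpec
#   # above), training produces an OutputSpec:
#   output = ExampleKNNTrainer().train(spec)
#   print(f"test score: {output.test_score:.2f}%")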