Start using warnings from stdlib.
[python_utils.git] / ml / model_trainer.py
#!/usr/bin/env python3

"""A framework for training and evaluating scikit-learn-style models
(anything with a .score() method) from feature files full of "key: value"
lines.  Children subclass TrainingBlueprint, implement train_model(), and
call train() with a populated InputSpec.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
import datetime
import glob
import logging
import os
import pickle
import random
import sys
from types import SimpleNamespace
from typing import Any, List, NamedTuple, Optional, Set, Tuple
import warnings

import numpy as np
from sklearn.model_selection import train_test_split  # type: ignore
from sklearn.preprocessing import MinMaxScaler  # type: ignore

from ansi import bold, reset
import argparse_utils
import config
from decorator_utils import timed
import parallelize as par

logger = logging.getLogger(__file__)

parser = config.add_commandline_args(
    f"ML Model Trainer ({__file__})",
    "Arguments related to training an ML model"
)
parser.add_argument(
    "--ml_trainer_quiet",
    action="store_true",
    help="Don't prompt the user for anything."
)
parser.add_argument(
    "--ml_trainer_delete",
    action="store_true",
    help="Delete invalid/incomplete features files in addition to warning."
)
group = parser.add_mutually_exclusive_group()
group.add_argument(
    "--ml_trainer_dry_run",
    action="store_true",
    help="Do not write a new model, just report efficacy.",
)
group.add_argument(
    "--ml_trainer_persist_threshold",
    type=argparse_utils.valid_percentage,
    metavar='0..100',
    help="Persist the model if the test set score is >= this threshold.",
)

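# Example (hypothetical) invocations of a trainer built on this module; the
# script name is made up, but the flags are the ones registered above:
#
#   ./my_trainer.py --ml_trainer_dry_run
#   ./my_trainer.py --ml_trainer_quiet --ml_trainer_persist_threshold 95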

class InputSpec(SimpleNamespace):
    """Describes the inputs to a training run.  Most fields are filled in by
    the child trainer; populate_from_config() supplies the four fields that
    are driven by commandline flags."""

    file_glob: str
    feature_count: int
    features_to_skip: Set[str]
    key_value_delimiter: str
    training_parameters: List
    label: str
    basename: str
    dry_run: Optional[bool]
    quiet: Optional[bool]
    persist_percentage_threshold: Optional[float]
    delete_bad_inputs: Optional[bool]

    @staticmethod
    def populate_from_config() -> InputSpec:
        return InputSpec(
            dry_run = config.config["ml_trainer_dry_run"],
            quiet = config.config["ml_trainer_quiet"],
            persist_percentage_threshold = config.config["ml_trainer_persist_threshold"],
            delete_bad_inputs = config.config["ml_trainer_delete"],
        )

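# Example (hypothetical values): how a child trainer might finish populating
# an InputSpec; populate_from_config() only fills in the flag-driven fields.
#
#   spec = InputSpec.populate_from_config()
#   spec.file_glob = "/data/features/*.features"
#   spec.feature_count = 12
#   spec.features_to_skip = {"hostname", "timestamp"}
#   spec.key_value_delimiter = ":"
#   spec.training_parameters = [{"n_neighbors": k} for k in (3, 5, 7)]
#   spec.label = "label"
#   spec.basename = "my_model"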

class OutputSpec(NamedTuple):
    """Describes the results of a training run."""

    model_filename: Optional[str]
    model_info_filename: Optional[str]
    scaler_filename: Optional[str]
    training_score: float
    test_score: float


class TrainingBlueprint(ABC):
    """A base class that handles the mechanics of training: reading feature
    files, splitting and scaling the data, scoring candidate models, and
    optionally persisting the best one.  Children implement train_model()."""

    def __init__(self):
        self.y_train = None
        self.y_test = None
        self.X_test_scaled = None
        self.X_train_scaled = None
        self.file_done_count = 0
        self.total_file_count = 0
        self.spec = None

    def train(self, spec: InputSpec) -> OutputSpec:
        """Train one model per entry in spec.training_parameters and return
        an OutputSpec describing the best one found."""
        import smart_future

        random.seed()
        self.spec = spec

        X_, y_ = self.read_input_files()
        num_examples = len(y_)

        # Every example's features
        X = np.array(X_)

        # Every example's label
        y = np.array(y_)

        print("Doing random test/train split...")
        X_train, X_test, self.y_train, self.y_test = self.test_train_split(
            X,
            y,
        )

        print("Scaling training data...")
        scaler, self.X_train_scaled, self.X_test_scaled = self.scale_data(
            X_train,
            X_test,
        )

        print("Training model(s)...")
        models = []
        modelid_to_params = {}
        for params in self.spec.training_parameters:
            model = self.train_model(
                params,
                self.X_train_scaled,
                self.y_train
            )
            models.append(model)
            modelid_to_params[model.get_id()] = str(params)

        best_model = None
        best_score = None
        best_test_score = None
        best_training_score = None
        best_params = None
        for model in smart_future.wait_any(models):
            params = modelid_to_params[model.get_id()]
            if isinstance(model, smart_future.SmartFuture):
                model = model._resolve()
            if model is not None:
                training_score, test_score = self.evaluate_model(
                    model,
                    self.X_train_scaled,
                    self.y_train,
                    self.X_test_scaled,
                    self.y_test,
                )
                # Rank models by a blend that weights the test set score much
                # more heavily than the training set score.
                score = (training_score + test_score * 20) / 21
                if not self.spec.quiet:
                    print(
                        f"{bold()}{params}{reset()}: "
                        f"Training set score={training_score:.2f}%, "
                        f"test set score={test_score:.2f}%",
                        file=sys.stderr,
                    )
                if best_score is None or score > best_score:
                    best_score = score
                    best_test_score = test_score
                    best_training_score = training_score
                    best_model = model
                    best_params = params
                    if not self.spec.quiet:
                        print(
                            f"New best score {best_score:.2f}% with params {params}"
                        )

        if not self.spec.quiet:
            msg = f"Done training; best test set score was: {best_test_score:.1f}%"
            print(msg)
            logger.info(msg)
        scaler_filename, model_filename, model_info_filename = (
            self.maybe_persist_scaler_and_model(
                best_training_score,
                best_test_score,
                best_params,
                num_examples,
                scaler,
                best_model,
            )
        )
        return OutputSpec(
            model_filename = model_filename,
            model_info_filename = model_info_filename,
            scaler_filename = scaler_filename,
            training_score = best_training_score,
            test_score = best_test_score,
        )

    @par.parallelize(method=par.Method.THREAD)
    def read_files_from_list(
            self,
            files: List[str],
            n: int
    ) -> Tuple[List, List]:
        """Read one shard of feature files and return (features, labels)."""
        # All features
        X = []

        # The label
        y = []

        for filename in files:
            wrote_label = False
            with open(filename, "r") as f:
                lines = f.readlines()

            # This example's features
            x = []
            for line in lines:

                # We expect lines in features files to be of the form:
                #
                # key: value
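                #
                # For example (hypothetical file contents):
                #
                #   temperature: 71.3
                #   humidity: 0.55
                #   label: True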
                line = line.strip()
                try:
                    (key, value) = line.split(self.spec.key_value_delimiter)
                except Exception:
                    logger.debug(f"WARNING: bad line in file {filename} '{line}', skipped")
                    continue

                key = key.strip()
                value = value.strip()
                if (self.spec.features_to_skip is not None
                        and key in self.spec.features_to_skip):
                    logger.debug(f"Skipping feature {key}")
                    continue

                value = self.normalize_feature(value)

                if key == self.spec.label:
                    y.append(value)
                    wrote_label = True
                else:
                    x.append(value)

            # Make sure we saw a label and the requisite number of features.
            if len(x) == self.spec.feature_count and wrote_label:
                X.append(x)
                self.file_done_count += 1
            else:
                if wrote_label:
                    y.pop()

                if self.spec.delete_bad_inputs:
                    msg = f"WARNING: {filename}: missing features or label; expected {self.spec.feature_count} but saw {len(x)}.  DELETING."
                    logger.warning(msg)
                    warnings.warn(msg)
                    os.remove(filename)
                else:
                    msg = f"WARNING: {filename}: missing features or label; expected {self.spec.feature_count} but saw {len(x)}.  Skipping."
                    logger.warning(msg)
                    warnings.warn(msg)
        return (X, y)

    def make_progress_graph(self) -> None:
        if not self.spec.quiet:
            from text_utils import progress_graph
            progress_graph(
                self.file_done_count,
                self.total_file_count
            )

    @timed
    def read_input_files(self):
        """Read all files matching spec.file_glob, in parallel shards of 500,
        and return (features, labels)."""
        import list_utils
        import smart_future

        # All features
        X = []

        # The label
        y = []

        results = []
        all_files = glob.glob(self.spec.file_glob)
        self.total_file_count = len(all_files)
        for n, files in enumerate(list_utils.shard(all_files, 500)):
            file_list = list(files)
            results.append(self.read_files_from_list(file_list, n))

        for result in smart_future.wait_any(results, callback=self.make_progress_graph):
            result = result._resolve()
            for z in result[0]:
                X.append(z)
            for z in result[1]:
                y.append(z)
        if not self.spec.quiet:
            print(" " * 80 + "\n")
        return (X, y)

    def normalize_feature(self, value: str) -> Any:
        """Convert a raw (string) feature value into an int: "False" and
        "None" become 0, "True" becomes 255, values containing a "." are
        parsed as floats and scaled by 100, and everything else is parsed
        as an int."""
        if value in ("False", "None"):
            ret = 0
        elif value == "True":
            ret = 255
        elif isinstance(value, str) and "." in value:
            ret = round(float(value) * 100.0)
        else:
            ret = int(value)
        return ret

    def test_train_split(self, X, y) -> List:
        logger.debug("Performing test/train split")
        return train_test_split(
            X,
            y,
            random_state=random.randrange(0, 1000),
        )

    def scale_data(self,
                   X_train: np.ndarray,
                   X_test: np.ndarray) -> Tuple[Any, np.ndarray, np.ndarray]:
        logger.debug("Scaling data")
        scaler = MinMaxScaler()
        scaler.fit(X_train)
        return (scaler, scaler.transform(X_train), scaler.transform(X_test))

    # Note: children should implement.  Consider using @parallelize.
    @abstractmethod
    def train_model(self,
                    parameters,
                    X_train_scaled: np.ndarray,
                    y_train: np.ndarray) -> Any:
        pass

    def evaluate_model(
            self,
            model: Any,
            X_train_scaled: np.ndarray,
            y_train: np.ndarray,
            X_test_scaled: np.ndarray,
            y_test: np.ndarray) -> Tuple[np.float64, np.float64]:
        logger.debug("Evaluating the model")
        training_score = model.score(X_train_scaled, y_train) * 100.0
        test_score = model.score(X_test_scaled, y_test) * 100.0
        logger.info(
            f"Model evaluation results: test_score={test_score:.5f}, "
            f"train_score={training_score:.5f}"
        )
        return (training_score, test_score)

    def maybe_persist_scaler_and_model(
            self,
            training_score: np.float64,
            test_score: np.float64,
            params: str,
            num_examples: int,
            scaler: Any,
            model: Any) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """Unless this is a dry run, offer to write the scaler, model, and a
        model info file to disk and return their filenames; returns
        (None, None, None) if nothing was persisted."""
        if not self.spec.dry_run:
            import datetime_utils
            import input_utils
            import string_utils

            now: datetime.datetime = datetime_utils.now_pacific()
            info = f"""Timestamp: {datetime_utils.datetime_to_string(now)}
Model params: {params}
Training examples: {num_examples}
Training set score: {training_score:.2f}%
Testing set score: {test_score:.2f}%"""
            print(f'\n{info}\n')
            if (
                    (self.spec.persist_percentage_threshold is not None and
                     test_score > self.spec.persist_percentage_threshold)
                    or
                    (not self.spec.quiet
                     and input_utils.yn_response("Write the model? [y,n]: ") == "y")
            ):
                scaler_filename = f"{self.spec.basename}_scaler.sav"
                with open(scaler_filename, "wb") as f:
                    pickle.dump(scaler, f)
                msg = f"Wrote {scaler_filename}"
                print(msg)
                logger.info(msg)
                model_filename = f"{self.spec.basename}_model.sav"
                with open(model_filename, "wb") as f:
                    pickle.dump(model, f)
                msg = f"Wrote {model_filename}"
                print(msg)
                logger.info(msg)
                model_info_filename = f"{self.spec.basename}_model_info.txt"
                with open(model_info_filename, "w") as f:
                    f.write(info)
                msg = f"Wrote {model_info_filename}:"
                print(msg)
                logger.info(msg)
                print(string_utils.indent(info, 2))
                logger.info(info)
                return (scaler_filename, model_filename, model_info_filename)
        return (None, None, None)
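

# A minimal, hypothetical sketch of a TrainingBlueprint child, for
# illustration only.  It assumes scikit-learn's KNeighborsClassifier and that
# each entry in spec.training_parameters is a dict of estimator kwargs; the
# class name and those assumptions are not part of this library.  As the note
# above train_model() suggests, the method is parallelized so that train()
# can wait on the returned futures.
class ExampleKNNTrainer(TrainingBlueprint):
    @par.parallelize(method=par.Method.THREAD)
    def train_model(self,
                    parameters,
                    X_train_scaled: np.ndarray,
                    y_train: np.ndarray) -> Any:
        from sklearn.neighbors import KNeighborsClassifier  # type: ignore

        # Fit one candidate model with this particular parameter set.
        model = KNeighborsClassifier(**parameters)
        model.fit(X_train_scaled, y_train)
        return model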