[python_utils.git] / ml / model_trainer.py
#!/usr/bin/env python3

from __future__ import annotations

from abc import ABC, abstractmethod
import datetime
import glob
import logging
import os
import pickle
import random
import sys
from types import SimpleNamespace
from typing import Any, List, NamedTuple, Optional, Set, Tuple
import warnings

import numpy as np
from sklearn.model_selection import train_test_split  # type: ignore
from sklearn.preprocessing import MinMaxScaler  # type: ignore

from ansi import bold, reset
import argparse_utils
import config
from decorator_utils import timed
import executors
import parallelize as par

logger = logging.getLogger(__file__)

parser = config.add_commandline_args(
    f"ML Model Trainer ({__file__})", "Arguments related to training an ML model"
)
parser.add_argument(
    "--ml_trainer_quiet",
    action="store_true",
    help="Don't prompt the user for anything.",
)
parser.add_argument(
    "--ml_trainer_delete",
    action="store_true",
    help="Delete invalid/incomplete features files in addition to warning.",
)
group = parser.add_mutually_exclusive_group()
group.add_argument(
    "--ml_trainer_dry_run",
    action="store_true",
    help="Do not write a new model, just report efficacy.",
)
group.add_argument(
    "--ml_trainer_persist_threshold",
    type=argparse_utils.valid_percentage,
    metavar="0..100",
    help="Persist the model if the test set score is >= this threshold.",
)


class InputSpec(SimpleNamespace):
    """A collection of settings that describes the inputs to training:
    where the features files live, how they are formatted, which features
    to skip, which parameters to try, and how training should behave."""

    file_glob: str
    feature_count: int
    features_to_skip: Set[str]
    key_value_delimiter: str
    training_parameters: List
    label: str
    basename: str
    dry_run: Optional[bool]
    quiet: Optional[bool]
    persist_percentage_threshold: Optional[float]
    delete_bad_inputs: Optional[bool]

    @staticmethod
    def populate_from_config() -> InputSpec:
        return InputSpec(
            dry_run=config.config["ml_trainer_dry_run"],
            quiet=config.config["ml_trainer_quiet"],
            persist_percentage_threshold=config.config["ml_trainer_persist_threshold"],
            delete_bad_inputs=config.config["ml_trainer_delete"],
        )


class OutputSpec(NamedTuple):
    """Describes the results of training: where the model, scaler, and
    metadata were written (None if not persisted) and the final scores."""

    model_filename: Optional[str]
    model_info_filename: Optional[str]
    scaler_filename: Optional[str]
    training_score: float
    test_score: float


class TrainingBlueprint(ABC):
    """An abstract base class that reads features files, does a test/train
    split, scales the data, trains one model per set of training parameters
    (see the abstract train_model method), evaluates each, and optionally
    persists the best scaler/model pair along with a metadata file."""

    def __init__(self):
        self.y_train = None
        self.y_test = None
        self.X_test_scaled = None
        self.X_train_scaled = None
        self.file_done_count = 0
        self.total_file_count = 0
        self.spec = None

    def train(self, spec: InputSpec) -> OutputSpec:
        import smart_future

        random.seed()
        self.spec = spec

        X_, y_ = self.read_input_files()
        num_examples = len(y_)

        # Every example's features
        X = np.array(X_)

        # Every example's label
        y = np.array(y_)

        print("Doing random test/train split...")
        X_train, X_test, self.y_train, self.y_test = self.test_train_split(
            X,
            y,
        )

        print("Scaling training data...")
        scaler, self.X_train_scaled, self.X_test_scaled = self.scale_data(
            X_train,
            X_test,
        )

        print("Training model(s)...")
        models = []
        modelid_to_params = {}
        for params in self.spec.training_parameters:
            model = self.train_model(params, self.X_train_scaled, self.y_train)
            models.append(model)
            modelid_to_params[model.get_id()] = str(params)

        best_model = None
        best_score = None
        best_test_score = None
        best_training_score = None
        best_params = None
        for model in smart_future.wait_any(models):
            params = modelid_to_params[model.get_id()]
            if isinstance(model, smart_future.SmartFuture):
                model = model._resolve()
            if model is not None:
                training_score, test_score = self.evaluate_model(
                    model,
                    self.X_train_scaled,
                    self.y_train,
                    self.X_test_scaled,
                    self.y_test,
                )
                # Blend the two scores, weighting the test set score 20x more
                # heavily than the training set score.
                score = (training_score + test_score * 20) / 21
                if not self.spec.quiet:
                    print(
                        f"{bold()}{params}{reset()}: "
                        f"Training set score={training_score:.2f}%, "
                        f"test set score={test_score:.2f}%",
                        file=sys.stderr,
                    )
                if best_score is None or score > best_score:
                    best_score = score
                    best_test_score = test_score
                    best_training_score = training_score
                    best_model = model
                    best_params = params
                    if not self.spec.quiet:
                        print(f"New best score {best_score:.2f}% with params {params}")

        # Shut down the default executors used by @parallelize regardless of
        # the quiet setting so worker pools don't linger after training.
        executors.DefaultExecutors().shutdown()
        if not self.spec.quiet:
            msg = f"Done training; best test set score was: {best_test_score:.1f}%"
            print(msg)
            logger.info(msg)

        (
            scaler_filename,
            model_filename,
            model_info_filename,
        ) = self.maybe_persist_scaler_and_model(
            best_training_score,
            best_test_score,
            best_params,
            num_examples,
            scaler,
            best_model,
        )
        return OutputSpec(
            model_filename=model_filename,
            model_info_filename=model_info_filename,
            scaler_filename=scaler_filename,
            training_score=best_training_score,
            test_score=best_test_score,
        )

    @par.parallelize(method=par.Method.THREAD)
    def read_files_from_list(self, files: List[str], n: int) -> Tuple[List, List]:
        # All features
        X = []

        # The label
        y = []

        for filename in files:
            wrote_label = False
            with open(filename, "r") as f:
                lines = f.readlines()

            # This example's features
            x = []
            for line in lines:

                # We expect lines in features files to be of the form:
                #
                # key: value
                line = line.strip()
                try:
                    (key, value) = line.split(self.spec.key_value_delimiter)
                except Exception:
                    logger.debug(
                        f"WARNING: bad line in file {filename} '{line}', skipped"
                    )
                    continue

                key = key.strip()
                value = value.strip()
                if (
                    self.spec.features_to_skip is not None
                    and key in self.spec.features_to_skip
                ):
                    logger.debug(f"Skipping feature {key}")
                    continue

                value = self.normalize_feature(value)

                if key == self.spec.label:
                    y.append(value)
                    wrote_label = True
                else:
                    x.append(value)

            # Make sure we saw a label and the requisite number of features.
            if len(x) == self.spec.feature_count and wrote_label:
                X.append(x)
                self.file_done_count += 1
            else:
                if wrote_label:
                    y.pop()

                if self.spec.delete_bad_inputs:
                    msg = f"WARNING: {filename}: missing features or label; expected {self.spec.feature_count} but saw {len(x)}.  DELETING."
                    logger.warning(msg)
                    warnings.warn(msg)
                    os.remove(filename)
                else:
                    msg = f"WARNING: {filename}: missing features or label; expected {self.spec.feature_count} but saw {len(x)}.  Skipping."
                    logger.warning(msg)
                    warnings.warn(msg)
        return (X, y)

    def make_progress_graph(self) -> None:
        if not self.spec.quiet:
            from text_utils import progress_graph

            progress_graph(self.file_done_count, self.total_file_count)

    @timed
    def read_input_files(self):
        import list_utils
        import smart_future

        # All features
        X = []

        # The label
        y = []

        results = []
        all_files = glob.glob(self.spec.file_glob)
        self.total_file_count = len(all_files)
        for n, files in enumerate(list_utils.shard(all_files, 500)):
            file_list = list(files)
            results.append(self.read_files_from_list(file_list, n))

        for result in smart_future.wait_any(results, callback=self.make_progress_graph):
            result = result._resolve()
            for z in result[0]:
                X.append(z)
            for z in result[1]:
                y.append(z)
        if not self.spec.quiet:
            print(" " * 80 + "\n")
        return (X, y)

    def normalize_feature(self, value: str) -> Any:
        # Map booleans/None onto small ints and scale floats up into an
        # integer range so that every feature ends up numeric.
        if value in ("False", "None"):
            ret = 0
        elif value == "True":
            ret = 255
        elif isinstance(value, str) and "." in value:
            ret = round(float(value) * 100.0)
        else:
            ret = int(value)
        return ret

    def test_train_split(self, X, y) -> List:
        logger.debug("Performing test/train split")
        return train_test_split(
            X,
            y,
            random_state=random.randrange(0, 1000),
        )

    def scale_data(
        self, X_train: np.ndarray, X_test: np.ndarray
    ) -> Tuple[Any, np.ndarray, np.ndarray]:
        logger.debug("Scaling data")
        scaler = MinMaxScaler()
        scaler.fit(X_train)
        return (scaler, scaler.transform(X_train), scaler.transform(X_test))

    # Note: children should implement.  Consider using @parallelize.
    @abstractmethod
    def train_model(
        self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray
    ) -> Any:
        pass

    def evaluate_model(
        self,
        model: Any,
        X_train_scaled: np.ndarray,
        y_train: np.ndarray,
        X_test_scaled: np.ndarray,
        y_test: np.ndarray,
    ) -> Tuple[np.float64, np.float64]:
        logger.debug("Evaluating the model")
        training_score = model.score(X_train_scaled, y_train) * 100.0
        test_score = model.score(X_test_scaled, y_test) * 100.0
        logger.info(
            f"Model evaluation results: test_score={test_score:.5f}, "
            f"train_score={training_score:.5f}"
        )
        return (training_score, test_score)

    def maybe_persist_scaler_and_model(
        self,
        training_score: np.float64,
        test_score: np.float64,
        params: str,
        num_examples: int,
        scaler: Any,
        model: Any,
    ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        if not self.spec.dry_run:
            import datetime_utils
            import input_utils
            import string_utils

            now: datetime.datetime = datetime_utils.now_pacific()
            info = f"""Timestamp: {datetime_utils.datetime_to_string(now)}
Model params: {params}
Training examples: {num_examples}
Training set score: {training_score:.2f}%
Testing set score: {test_score:.2f}%"""
            print(f"\n{info}\n")
            if (
                self.spec.persist_percentage_threshold is not None
                and test_score > self.spec.persist_percentage_threshold
            ) or (
                not self.spec.quiet
                and input_utils.yn_response("Write the model? [y,n]: ") == "y"
            ):
                scaler_filename = f"{self.spec.basename}_scaler.sav"
                with open(scaler_filename, "wb") as f:
                    pickle.dump(scaler, f)
                msg = f"Wrote {scaler_filename}"
                print(msg)
                logger.info(msg)
                model_filename = f"{self.spec.basename}_model.sav"
                with open(model_filename, "wb") as f:
                    pickle.dump(model, f)
                msg = f"Wrote {model_filename}"
                print(msg)
                logger.info(msg)
                model_info_filename = f"{self.spec.basename}_model_info.txt"
                with open(model_info_filename, "w") as f:
                    f.write(info)
                msg = f"Wrote {model_info_filename}:"
                print(msg)
                logger.info(msg)
                print(string_utils.indent(info, 2))
                logger.info(info)
                return (scaler_filename, model_filename, model_info_filename)
        return (None, None, None)
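
For illustration, here is a minimal sketch of what a concrete child class might look like. It assumes scikit-learn's KNeighborsClassifier as the estimator, reuses the @par.parallelize(method=par.Method.THREAD) decoration already used by read_files_from_list above, and assumes it lives in (or imports everything from) this module. The class name, file glob, feature count, delimiter, label key, and parameter values are hypothetical placeholders, not taken from this repository.

class ExampleKNNTrainer(TrainingBlueprint):
    # Hypothetical child class, not part of the original module.  Because
    # train_model is decorated with @par.parallelize it returns a
    # SmartFuture, which is what TrainingBlueprint.train() expects: it calls
    # .get_id() on each result and resolves the futures as they complete.
    @par.parallelize(method=par.Method.THREAD)
    def train_model(
        self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray
    ) -> Any:
        from sklearn.neighbors import KNeighborsClassifier  # type: ignore

        # Fit one k-nearest-neighbors classifier per training parameter.
        model = KNeighborsClassifier(n_neighbors=parameters)
        model.fit(X_train_scaled, y_train)
        return model


# A hypothetical invocation; dry_run=True reports efficacy without writing
# any model/scaler files (the paths and counts below are made up):
#
#   spec = InputSpec(
#       file_glob="/tmp/features/*.dat",
#       feature_count=10,
#       features_to_skip=set(),
#       key_value_delimiter=":",
#       training_parameters=[3, 5, 7],   # one KNN model per n_neighbors
#       label="label",
#       basename="example_knn",
#       dry_run=True,
#       quiet=False,
#       persist_percentage_threshold=None,
#       delete_bad_inputs=False,
#   )
#   ExampleKNNTrainer().train(spec)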