#!/usr/bin/env python3

from __future__ import annotations

from abc import ABC, abstractmethod
import datetime
import glob
import logging
import os
import pickle
import random
import sys
from types import SimpleNamespace
from typing import Any, List, NamedTuple, Optional, Set, Tuple
import warnings

import numpy as np
from sklearn.model_selection import train_test_split  # type: ignore
from sklearn.preprocessing import MinMaxScaler  # type: ignore

from ansi import bold, reset
import argparse_utils
import config
from decorator_utils import timed
import executors
import parallelize as par

logger = logging.getLogger(__file__)

parser = config.add_commandline_args(
    f"ML Model Trainer ({__file__})",
    "Arguments related to training an ML model",
)
parser.add_argument(
    "--ml_trainer_quiet",
    action="store_true",
    help="Don't prompt the user for anything.",
)
parser.add_argument(
    "--ml_trainer_delete",
    action="store_true",
    help="Delete invalid/incomplete features files in addition to warning.",
)
group = parser.add_mutually_exclusive_group()
group.add_argument(
    "--ml_trainer_dry_run",
    action="store_true",
    help="Do not write a new model, just report efficacy.",
)
group.add_argument(
    "--ml_trainer_persist_threshold",
    type=argparse_utils.valid_percentage,
    metavar="0..100",
    help="Persist the model if the test set score is >= this threshold.",
)
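
# Example invocation of a hypothetical trainer script built on top of this
# module (the flag names come from the arguments registered above; the
# script name and threshold value are made up for illustration):
#
#   ./my_trainer.py --ml_trainer_dry_run
#   ./my_trainer.py --ml_trainer_quiet --ml_trainer_persist_threshold 95.0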


class InputSpec(SimpleNamespace):
    """Describes one training run: where the features files live, how they
    are formatted, which feature is the label, and how the resulting model
    should be handled."""

    file_glob: str
    feature_count: int
    features_to_skip: Set[str]
    key_value_delimiter: str
    training_parameters: List
    label: str
    basename: str
    dry_run: Optional[bool]
    quiet: Optional[bool]
    persist_percentage_threshold: Optional[float]
    delete_bad_inputs: Optional[bool]

    @staticmethod
    def populate_from_config() -> InputSpec:
        """Populate the flag-driven fields from the commandline arguments
        registered above; callers must fill in the rest."""
        return InputSpec(
            dry_run=config.config["ml_trainer_dry_run"],
            quiet=config.config["ml_trainer_quiet"],
            persist_percentage_threshold=config.config["ml_trainer_persist_threshold"],
            delete_bad_inputs=config.config["ml_trainer_delete"],
        )


class OutputSpec(NamedTuple):
    """What a training run produced: the filenames persisted (if any) and
    the best model's scores on the training and test sets."""

    model_filename: Optional[str]
    model_info_filename: Optional[str]
    scaler_filename: Optional[str]
    training_score: float
    test_score: float


class TrainingBlueprint(ABC):
    """Abstract base class that drives the training pipeline: read features
    files, split and scale the data, train one candidate model per parameter
    set (children implement train_model), evaluate the candidates, and
    optionally persist the best one."""

    def __init__(self):
        self.y_train = None
        self.y_test = None
        self.X_test_scaled = None
        self.X_train_scaled = None
        self.file_done_count = 0
        self.total_file_count = 0
        self.spec = None

    def train(self, spec: InputSpec) -> OutputSpec:
        """Run the full training pipeline described by spec and return an
        OutputSpec describing what was produced."""
        import smart_future

        random.seed()
        self.spec = spec

        X_, y_ = self.read_input_files()
        num_examples = len(y_)

        # Every example's features
        X = np.array(X_)

        # Every example's label
        y = np.array(y_)

        print("Doing random test/train split...")
        X_train, X_test, self.y_train, self.y_test = self.test_train_split(
            X,
            y,
        )

        print("Scaling training data...")
        scaler, self.X_train_scaled, self.X_test_scaled = self.scale_data(
            X_train,
            X_test,
        )

        print("Training model(s)...")
        models = []
        modelid_to_params = {}
        for params in self.spec.training_parameters:
            model = self.train_model(
                params,
                self.X_train_scaled,
                self.y_train
            )
            models.append(model)
            modelid_to_params[model.get_id()] = str(params)

        best_model = None
        best_score = None
        best_test_score = None
        best_training_score = None
        best_params = None
        for model in smart_future.wait_any(models):
            params = modelid_to_params[model.get_id()]
            if isinstance(model, smart_future.SmartFuture):
                model = model._resolve()
            if model is not None:
                training_score, test_score = self.evaluate_model(
                    model,
                    self.X_train_scaled,
                    self.y_train,
                    self.X_test_scaled,
                    self.y_test,
                )

                # Blend the two scores, weighting the test set score 20x more
                # heavily than the training set score.
                score = (training_score + test_score * 20) / 21
                if not self.spec.quiet:
                    print(
                        f"{bold()}{params}{reset()}: "
                        f"Training set score={training_score:.2f}%, "
                        f"test set score={test_score:.2f}%",
                        file=sys.stderr,
                    )
                if best_score is None or score > best_score:
                    best_score = score
                    best_test_score = test_score
                    best_training_score = training_score
                    best_model = model
                    best_params = params
                    if not self.spec.quiet:
                        print(
                            f"New best score {best_score:.2f}% with params {params}"
                        )

        # Shut down the parallelizer's executors regardless of verbosity so
        # background workers don't outlive the training run.
        executors.DefaultExecutors().shutdown()
        if not self.spec.quiet:
            msg = f"Done training; best test set score was: {best_test_score:.1f}%"
            print(msg)
            logger.info(msg)

        scaler_filename, model_filename, model_info_filename = (
            self.maybe_persist_scaler_and_model(
                best_training_score,
                best_test_score,
                best_params,
                num_examples,
                scaler,
                best_model,
            )
        )
        return OutputSpec(
            model_filename=model_filename,
            model_info_filename=model_info_filename,
            scaler_filename=scaler_filename,
            training_score=best_training_score,
            test_score=best_test_score,
        )

    @par.parallelize(method=par.Method.THREAD)
    def read_files_from_list(
            self,
            files: List[str],
            n: int
    ) -> Tuple[List, List]:
        # All features
        X = []

        # The label
        y = []

        for filename in files:
            wrote_label = False
            with open(filename, "r") as f:
                lines = f.readlines()

            # This example's features
            x = []
            for line in lines:

                # We expect lines in features files to be of the form:
                #
                # key: value
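                #
                # for instance (a made-up example):
                #
                #   speed: 12.5
                #   weight: 44
                #   label: 1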
                line = line.strip()
                try:
                    (key, value) = line.split(self.spec.key_value_delimiter)
                except Exception:
                    logger.debug(f"WARNING: bad line in file {filename} '{line}', skipped")
                    continue

                key = key.strip()
                value = value.strip()
                if (self.spec.features_to_skip is not None
                        and key in self.spec.features_to_skip):
                    logger.debug(f"Skipping feature {key}")
                    continue

                value = self.normalize_feature(value)

                if key == self.spec.label:
                    y.append(value)
                    wrote_label = True
                else:
                    x.append(value)

            # Make sure we saw a label and the requisite number of features.
            if len(x) == self.spec.feature_count and wrote_label:
                X.append(x)
                self.file_done_count += 1
            else:
                # Incomplete example: drop its label (if we kept one) so X
                # and y stay in sync.
                if wrote_label:
                    y.pop()

                msg = (
                    f"WARNING: {filename}: missing features or label; "
                    f"expected {self.spec.feature_count} but saw {len(x)}."
                )
                if self.spec.delete_bad_inputs:
                    msg += "  DELETING."
                    logger.warning(msg)
                    warnings.warn(msg)
                    os.remove(filename)
                else:
                    msg += "  Skipping."
                    logger.warning(msg)
                    warnings.warn(msg)
        return (X, y)

    def make_progress_graph(self) -> None:
        if not self.spec.quiet:
            from text_utils import progress_graph
            progress_graph(
                self.file_done_count,
                self.total_file_count
            )

    @timed
    def read_input_files(self):
        import list_utils
        import smart_future

        # All features
        X = []

        # The label
        y = []

        # Read the features files in sharded batches; read_files_from_list is
        # decorated with @par.parallelize so each batch is read in the
        # background.
        results = []
        all_files = glob.glob(self.spec.file_glob)
        self.total_file_count = len(all_files)
        for n, files in enumerate(list_utils.shard(all_files, 500)):
            file_list = list(files)
            results.append(self.read_files_from_list(file_list, n))

        # Gather results as the batches complete, updating the progress graph
        # along the way.
        for result in smart_future.wait_any(results, callback=self.make_progress_graph):
            result = result._resolve()
            for z in result[0]:
                X.append(z)
            for z in result[1]:
                y.append(z)
        if not self.spec.quiet:
            print(" " * 80 + "\n")
        return (X, y)

    def normalize_feature(self, value: str) -> Any:
        """Convert a raw string feature value into a number: "False" and
        "None" map to 0, "True" maps to 255, values containing a '.' are
        parsed as floats then scaled by 100 and rounded, and anything else
        is parsed as an int."""
        if value in ("False", "None"):
            ret = 0
        elif value == "True":
            ret = 255
        elif isinstance(value, str) and "." in value:
            ret = round(float(value) * 100.0)
        else:
            ret = int(value)
        return ret

    def test_train_split(self, X, y) -> List:
        logger.debug("Performing test/train split")
        return train_test_split(
            X,
            y,
            random_state=random.randrange(0, 1000),
        )

    def scale_data(self,
                   X_train: np.ndarray,
                   X_test: np.ndarray) -> Tuple[Any, np.ndarray, np.ndarray]:
        logger.debug("Scaling data")
        scaler = MinMaxScaler()
        scaler.fit(X_train)
        return (scaler, scaler.transform(X_train), scaler.transform(X_test))

    @abstractmethod
    def train_model(self,
                    parameters,
                    X_train_scaled: np.ndarray,
                    y_train: np.ndarray) -> Any:
        """Train one candidate model with the given parameters.  Children
        must implement this; consider decorating the implementation with
        @par.parallelize so that candidates train concurrently (train()
        resolves the SmartFutures such a decorated method returns)."""
        pass

    def evaluate_model(
            self,
            model: Any,
            X_train_scaled: np.ndarray,
            y_train: np.ndarray,
            X_test_scaled: np.ndarray,
            y_test: np.ndarray) -> Tuple[np.float64, np.float64]:
        """Score the model against both the training set and the test set
        and return the two scores as percentages."""
        logger.debug("Evaluating the model")
        training_score = model.score(X_train_scaled, y_train) * 100.0
        test_score = model.score(X_test_scaled, y_test) * 100.0
        logger.info(
            f"Model evaluation results: test_score={test_score:.5f}, "
            f"train_score={training_score:.5f}"
        )
        return (training_score, test_score)

    def maybe_persist_scaler_and_model(
            self,
            training_score: np.float64,
            test_score: np.float64,
            params: str,
            num_examples: int,
            scaler: Any,
            model: Any) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """Unless this is a dry run, maybe write the scaler, the model, and a
        small info file to disk (when the test score clears the persist
        threshold or the user says yes).  Returns the filenames written, or
        (None, None, None) if nothing was persisted."""
        if not self.spec.dry_run:
            import datetime_utils
            import input_utils
            import string_utils

            now: datetime.datetime = datetime_utils.now_pacific()
            info = f"""Timestamp: {datetime_utils.datetime_to_string(now)}
Model params: {params}
Training examples: {num_examples}
Training set score: {training_score:.2f}%
Testing set score: {test_score:.2f}%"""
            print(f"\n{info}\n")
            if (
                    (self.spec.persist_percentage_threshold is not None and
                     test_score >= self.spec.persist_percentage_threshold)
                    or
                    (not self.spec.quiet
                     and input_utils.yn_response("Write the model? [y,n]: ") == "y")
            ):
                scaler_filename = f"{self.spec.basename}_scaler.sav"
                with open(scaler_filename, "wb") as f:
                    pickle.dump(scaler, f)
                msg = f"Wrote {scaler_filename}"
                print(msg)
                logger.info(msg)
                model_filename = f"{self.spec.basename}_model.sav"
                with open(model_filename, "wb") as f:
                    pickle.dump(model, f)
                msg = f"Wrote {model_filename}"
                print(msg)
                logger.info(msg)
                model_info_filename = f"{self.spec.basename}_model_info.txt"
                with open(model_info_filename, "w") as f:
                    f.write(info)
                msg = f"Wrote {model_info_filename}:"
                print(msg)
                logger.info(msg)
                print(string_utils.indent(info, 2))
                logger.info(info)
                return (scaler_filename, model_filename, model_info_filename)
        return (None, None, None)
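

# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of this module's API): one way a child
# class might implement train_model and drive a training run.  The model
# choice (sklearn's KNeighborsRegressor), the file glob, feature count,
# label name, and every other literal below are hypothetical placeholders.
# ---------------------------------------------------------------------------
class ExampleKNNTrainer(TrainingBlueprint):
    @par.parallelize(method=par.Method.THREAD)
    def train_model(self,
                    parameters,
                    X_train_scaled: np.ndarray,
                    y_train: np.ndarray) -> Any:
        # Train one candidate model per parameter dict; TrainingBlueprint.train
        # resolves the future this decorated method returns and keeps the best
        # scoring candidate.
        from sklearn.neighbors import KNeighborsRegressor  # type: ignore

        model = KNeighborsRegressor(n_neighbors=parameters["n_neighbors"])
        model.fit(X_train_scaled, y_train)
        return model


def example_training_run() -> OutputSpec:
    # All values here are made up; a real caller would describe its own
    # features files and parameter grid.
    spec = InputSpec(
        file_glob="/tmp/features/*.txt",
        feature_count=3,
        features_to_skip=set(),
        key_value_delimiter=":",
        training_parameters=[{"n_neighbors": k} for k in (3, 5, 7)],
        label="label",
        basename="example",
        dry_run=True,
        quiet=True,
        persist_percentage_threshold=None,
        delete_bad_inputs=False,
    )
    return ExampleKNNTrainer().train(spec)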