Random cleanups and type safety. Created ml subdir.
[python_utils.git] / ml / model_trainer.py
#!/usr/bin/env python3
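"""Perform the plumbing around training a machine learning model:
read "key: value" features files, split and scale the data, train one
or more candidate models, evaluate them, and optionally persist the
best scaler/model pair.  See TrainingBlueprint below."""
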
from __future__ import annotations

from abc import ABC, abstractmethod
import datetime
import glob
import logging
import os
import pickle
import random
import sys
from types import SimpleNamespace
from typing import Any, List, NamedTuple, Optional, Set, Tuple

import numpy as np
from sklearn.model_selection import train_test_split  # type: ignore
from sklearn.preprocessing import MinMaxScaler  # type: ignore

from ansi import bold, reset
import argparse_utils
import config
from decorator_utils import timed
import parallelize as par

logger = logging.getLogger(__file__)

parser = config.add_commandline_args(
    f"ML Model Trainer ({__file__})",
    "Arguments related to training an ML model",
)
parser.add_argument(
    "--ml_trainer_quiet",
    action="store_true",
    help="Don't prompt the user for anything.",
)
parser.add_argument(
    "--ml_trainer_delete",
    action="store_true",
    help="Delete invalid/incomplete features files in addition to warning.",
)
group = parser.add_mutually_exclusive_group()
group.add_argument(
    "--ml_trainer_dry_run",
    action="store_true",
    help="Do not write a new model, just report efficacy.",
)
group.add_argument(
    "--ml_trainer_persist_threshold",
    type=argparse_utils.valid_percentage,
    metavar="0..100",
    help="Persist the model if the test set score is >= this threshold.",
)


class InputSpec(SimpleNamespace):
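    """Describes the training inputs and trainer behavior: where the
    features files live, how to parse them, which candidate model
    parameters to try, and whether/when to persist the results."""
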
    file_glob: str
    feature_count: int
    features_to_skip: Set[str]
    key_value_delimiter: str
    training_parameters: List
    label: str
    basename: str
    dry_run: Optional[bool]
    quiet: Optional[bool]
    persist_percentage_threshold: Optional[float]
    delete_bad_inputs: Optional[bool]

    @staticmethod
    def populate_from_config() -> InputSpec:
        return InputSpec(
            dry_run = config.config["ml_trainer_dry_run"],
            quiet = config.config["ml_trainer_quiet"],
            persist_percentage_threshold = config.config["ml_trainer_persist_threshold"],
            delete_bad_inputs = config.config["ml_trainer_delete"],
        )


class OutputSpec(NamedTuple):
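    """The trainer's results: the filenames of any persisted artifacts
    (None when nothing was written) plus the best model's scores."""
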
    model_filename: Optional[str]
    model_info_filename: Optional[str]
    scaler_filename: Optional[str]
    training_score: float
    test_score: float


class TrainingBlueprint(ABC):
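    """An abstract base class that reads features files, splits and scales
    the data, trains one candidate model per entry in
    spec.training_parameters, evaluates each, and optionally persists the
    best scaler and model.  Subclasses must implement train_model()."""
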
    def __init__(self):
        self.y_train = None
        self.y_test = None
        self.X_test_scaled = None
        self.X_train_scaled = None
        self.file_done_count = 0
        self.total_file_count = 0
        self.spec = None

    def train(self, spec: InputSpec) -> OutputSpec:
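        """Run the full pipeline described by spec and return an OutputSpec
        summarizing the best model found."""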
        import smart_future

        random.seed()
        self.spec = spec

        X_, y_ = self.read_input_files()
        num_examples = len(y_)

        # Every example's features
        X = np.array(X_)

        # Every example's label
        y = np.array(y_)

        print("Doing random test/train split...")
        X_train, X_test, self.y_train, self.y_test = self.test_train_split(
            X,
            y,
        )

        print("Scaling training data...")
        scaler, self.X_train_scaled, self.X_test_scaled = self.scale_data(
            X_train,
            X_test,
        )

        print("Training model(s)...")
        models = []
        modelid_to_params = {}
        for params in self.spec.training_parameters:
            model = self.train_model(
                params,
                self.X_train_scaled,
                self.y_train
            )
            models.append(model)
            modelid_to_params[model.get_id()] = str(params)

        best_model = None
        best_score = None
        best_test_score = None
        best_training_score = None
        best_params = None
        for model in smart_future.wait_any(models):
            params = modelid_to_params[model.get_id()]
            if isinstance(model, smart_future.SmartFuture):
                model = model._resolve()
            if model is not None:
                training_score, test_score = self.evaluate_model(
                    model,
                    self.X_train_scaled,
                    self.y_train,
                    self.X_test_scaled,
                    self.y_test,
                )
                # Blend the two scores, weighting the test set score much
                # more heavily than the training set score.
                score = (training_score + test_score * 20) / 21
                if not self.spec.quiet:
                    print(
                        f"{bold()}{params}{reset()}: "
                        f"Training set score={training_score:.2f}%, "
                        f"test set score={test_score:.2f}%",
                        file=sys.stderr,
                    )
                if best_score is None or score > best_score:
                    best_score = score
                    best_test_score = test_score
                    best_training_score = training_score
                    best_model = model
                    best_params = params
                    if not self.spec.quiet:
                        print(
                            f"New best score {best_score:.2f}% with params {params}"
                        )

        if not self.spec.quiet:
            msg = f"Done training; best test set score was: {best_test_score:.1f}%"
            print(msg)
            logger.info(msg)
        scaler_filename, model_filename, model_info_filename = (
            self.maybe_persist_scaler_and_model(
                best_training_score,
                best_test_score,
                best_params,
                num_examples,
                scaler,
                best_model,
            )
        )
        return OutputSpec(
            model_filename = model_filename,
            model_info_filename = model_info_filename,
            scaler_filename = scaler_filename,
            training_score = best_training_score,
            test_score = best_test_score,
        )

    @par.parallelize(method=par.Method.THREAD)
    def read_files_from_list(
            self,
            files: List[str],
            n: int
    ) -> Tuple[List, List]:
        # All features
        X = []

        # The label
        y = []

        for filename in files:
            wrote_label = False
            with open(filename, "r") as f:
                lines = f.readlines()

            # This example's features
            x = []
            for line in lines:

                # We expect lines in features files to be of the form:
                #
                # key: value
                line = line.strip()
                try:
                    (key, value) = line.split(self.spec.key_value_delimiter)
                except Exception:
                    logger.warning(f"Bad line in file {filename} '{line}'; skipped.")
                    continue

                key = key.strip()
                value = value.strip()
                if (self.spec.features_to_skip is not None
                        and key in self.spec.features_to_skip):
                    logger.debug(f"Skipping feature {key}")
                    continue

                value = self.normalize_feature(value)

                if key == self.spec.label:
                    y.append(value)
                    wrote_label = True
                else:
                    x.append(value)

            # Make sure we saw a label and the requisite number of features.
            if len(x) == self.spec.feature_count and wrote_label:
                X.append(x)
                self.file_done_count += 1
            else:
                if wrote_label:
                    y.pop()

                if self.spec.delete_bad_inputs:
                    msg = f"WARNING: {filename}: missing features or label.  DELETING."
                    print(msg, file=sys.stderr)
                    logger.warning(msg)
                    os.remove(filename)
                else:
                    msg = f"WARNING: {filename}: missing features or label.  Skipped."
                    print(msg, file=sys.stderr)
                    logger.warning(msg)
        return (X, y)

    def make_progress_graph(self) -> None:
        if not self.spec.quiet:
            from text_utils import progress_graph
            progress_graph(
                self.file_done_count,
                self.total_file_count
            )

    @timed
    def read_input_files(self):
        import list_utils
        import smart_future

        # All features
        X = []

        # The label
        y = []

        results = []
        all_files = glob.glob(self.spec.file_glob)
        self.total_file_count = len(all_files)
        for n, files in enumerate(list_utils.shard(all_files, 500)):
            file_list = list(files)
            results.append(self.read_files_from_list(file_list, n))

        for result in smart_future.wait_any(results, callback=self.make_progress_graph):
            result = result._resolve()
            for z in result[0]:
                X.append(z)
            for z in result[1]:
                y.append(z)
        if not self.spec.quiet:
            print(" " * 80 + "\n")
        return (X, y)

    def normalize_feature(self, value: str) -> Any:
        """Convert a raw feature string from a features file into a number."""
        if value in ("False", "None"):
            ret = 0
        elif value == "True":
            ret = 255
        elif isinstance(value, str) and "." in value:
            ret = round(float(value) * 100.0)
        else:
            ret = int(value)
        return ret

    def test_train_split(self, X, y) -> List:
        logger.debug("Performing test/train split")
        return train_test_split(
            X,
            y,
            random_state=random.randrange(0, 1000),
        )

    def scale_data(self,
                   X_train: np.ndarray,
                   X_test: np.ndarray) -> Tuple[Any, np.ndarray, np.ndarray]:
        logger.debug("Scaling data")
        scaler = MinMaxScaler()
        scaler.fit(X_train)
        return (scaler, scaler.transform(X_train), scaler.transform(X_test))

    # Note: children should implement.  Consider using @parallelize.
    @abstractmethod
    def train_model(self,
                    parameters,
                    X_train_scaled: np.ndarray,
                    y_train: np.ndarray) -> Any:
        pass

    def evaluate_model(
            self,
            model: Any,
            X_train_scaled: np.ndarray,
            y_train: np.ndarray,
            X_test_scaled: np.ndarray,
            y_test: np.ndarray) -> Tuple[np.float64, np.float64]:
        logger.debug("Evaluating the model")
        training_score = model.score(X_train_scaled, y_train) * 100.0
        test_score = model.score(X_test_scaled, y_test) * 100.0
        logger.info(
            f"Model evaluation results: test_score={test_score:.5f}, "
            f"train_score={training_score:.5f}"
        )
        return (training_score, test_score)

    def maybe_persist_scaler_and_model(
            self,
            training_score: np.float64,
            test_score: np.float64,
            params: str,
            num_examples: int,
            scaler: Any,
            model: Any) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        if not self.spec.dry_run:
            import datetime_utils
            import input_utils
            import string_utils

            if (
                    (self.spec.persist_percentage_threshold is not None and
                     test_score > self.spec.persist_percentage_threshold)
                    or
                    (not self.spec.quiet
                     and input_utils.yn_response("Write the model? [y,n]: ") == "y")
            ):
                scaler_filename = f"{self.spec.basename}_scaler.sav"
                with open(scaler_filename, "wb") as f:
                    pickle.dump(scaler, f)
                msg = f"Wrote {scaler_filename}"
                print(msg)
                logger.info(msg)
                model_filename = f"{self.spec.basename}_model.sav"
                with open(model_filename, "wb") as f:
                    pickle.dump(model, f)
                msg = f"Wrote {model_filename}"
                print(msg)
                logger.info(msg)
                model_info_filename = f"{self.spec.basename}_model_info.txt"
                now: datetime.datetime = datetime_utils.now_pst()
                info = f"""Timestamp: {datetime_utils.datetime_to_string(now)}
Model params: {params}
Training examples: {num_examples}
Training set score: {training_score:.2f}%
Testing set score: {test_score:.2f}%"""
                with open(model_info_filename, "w") as f:
                    f.write(info)
                msg = f"Wrote {model_info_filename}:"
                print(msg)
                logger.info(msg)
                print(string_utils.indent(info, 2))
                logger.info(info)
                return (scaler_filename, model_filename, model_info_filename)
        return (None, None, None)
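
# ---------------------------------------------------------------------------
# Example usage (a hypothetical sketch, not part of this module): a minimal
# TrainingBlueprint subclass.  KnnTrainer, the dict-shaped
# training_parameters, and the file/feature settings below are illustrative
# assumptions only.  Note that train() calls get_id() on whatever
# train_model() returns, so parallelizing train_model (as the base class's
# comment suggests) keeps the contract satisfied via SmartFutures.
#
#   from sklearn.neighbors import KNeighborsRegressor  # type: ignore
#
#   class KnnTrainer(TrainingBlueprint):
#       @par.parallelize(method=par.Method.THREAD)
#       def train_model(self, parameters, X_train_scaled, y_train):
#           # Fit one candidate model per parameter dict.
#           model = KNeighborsRegressor(**parameters)
#           model.fit(X_train_scaled, y_train)
#           return model
#
#   spec = InputSpec.populate_from_config()
#   spec.file_glob = "/tmp/features/*.txt"
#   spec.feature_count = 10
#   spec.features_to_skip = set()
#   spec.key_value_delimiter = ":"
#   spec.label = "label"
#   spec.basename = "knn"
#   spec.training_parameters = [{"n_neighbors": k} for k in (3, 5, 7)]
#   print(KnnTrainer().train(spec))
# ---------------------------------------------------------------------------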