#!/usr/bin/env python3

from __future__ import annotations

from abc import ABC, abstractmethod
import datetime
import glob
import logging
import os
import pickle
import random
import sys
from types import SimpleNamespace
from typing import Any, List, NamedTuple, Optional, Set, Tuple

import numpy as np
from sklearn.model_selection import train_test_split  # type: ignore
from sklearn.preprocessing import MinMaxScaler  # type: ignore

from ansi import bold, reset
import argparse_utils
import config
import datetime_utils
import decorator_utils
import input_utils
import list_utils
import parallelize as par
import smart_future
import string_utils
import text_utils

logger = logging.getLogger(__file__)

parser = config.add_commandline_args(
    f"ML Model Trainer ({__file__})",
    "Arguments related to training an ML model"
)
parser.add_argument(
    "--ml_trainer_quiet",
    action="store_true",
    help="Don't prompt the user for anything."
)
parser.add_argument(
    "--ml_trainer_delete",
    action="store_true",
    help="Delete invalid/incomplete features files in addition to warning."
)
group = parser.add_mutually_exclusive_group()
group.add_argument(
    "--ml_trainer_dry_run",
    action="store_true",
    help="Do not write a new model, just report efficacy.",
)
group.add_argument(
    "--ml_trainer_persist_threshold",
    type=argparse_utils.valid_percentage,
    metavar="0..100",
    help="Persist the model if the test set score is >= this threshold.",
)
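
# Example invocation of a concrete trainer built on top of this module
# (hypothetical script name and values):
#
#   ./example_trainer.py --ml_trainer_quiet --ml_trainer_persist_threshold 95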


class InputSpec(SimpleNamespace):
    """Describes the inputs and knobs for one model training run."""

    file_glob: str
    feature_count: int
    features_to_skip: Set[str]
    key_value_delimiter: str
    training_parameters: List
    label: str
    basename: str
    dry_run: Optional[bool]
    quiet: Optional[bool]
    persist_percentage_threshold: Optional[float]
    delete_bad_inputs: Optional[bool]

    @staticmethod
    def populate_from_config() -> InputSpec:
        return InputSpec(
            dry_run=config.config["ml_trainer_dry_run"],
            quiet=config.config["ml_trainer_quiet"],
            persist_percentage_threshold=config.config["ml_trainer_persist_threshold"],
            delete_bad_inputs=config.config["ml_trainer_delete"],
        )
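
    # Note that populate_from_config() seeds only the command-line driven
    # fields (dry_run, quiet, persist_percentage_threshold, delete_bad_inputs);
    # callers are expected to set the remaining attributes (file_glob,
    # feature_count, etc.) on the returned namespace themselves.  See the
    # example sketch at the bottom of this file.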


class OutputSpec(NamedTuple):
    """Filenames (if persisted) and scores produced by a training run."""

    model_filename: Optional[str]
    model_info_filename: Optional[str]
    scaler_filename: Optional[str]
    training_score: float
    test_score: float


class TrainingBlueprint(ABC):
    """Orchestrates reading feature files, splitting and scaling the data,
    training one candidate model per set of training parameters, and
    persisting the best model found.  Subclasses implement train_model().
    """

    def __init__(self):
        self.y_train = None
        self.y_test = None
        self.X_test_scaled = None
        self.X_train_scaled = None
        self.file_done_count = 0
        self.total_file_count = 0
        self.spec = None

    def train(self, spec: InputSpec) -> OutputSpec:
        """Run the full training pipeline described by spec."""
        random.seed()
        self.spec = spec

        X_, y_ = self.read_input_files()
        num_examples = len(y_)

        # Every example's features
        X = np.array(X_)

        # Every example's label
        y = np.array(y_)

        print("Doing random test/train split...")
        X_train, X_test, self.y_train, self.y_test = self.test_train_split(
            X,
            y,
        )

        print("Scaling training data...")
        scaler, self.X_train_scaled, self.X_test_scaled = self.scale_data(
            X_train,
            X_test,
        )

        print("Training model(s)...")
        models = []
        modelid_to_params = {}
        for params in self.spec.training_parameters:
            model = self.train_model(
                params,
                self.X_train_scaled,
                self.y_train
            )
            models.append(model)
            modelid_to_params[model.get_id()] = str(params)

        best_model = None
        best_score = None
        best_test_score = None
        best_training_score = None
        best_params = None
        for model in smart_future.wait_any(models):
            params = modelid_to_params[model.get_id()]
            if isinstance(model, smart_future.SmartFuture):
                model = model._resolve()
            if model is not None:
                training_score, test_score = self.evaluate_model(
                    model,
                    self.X_train_scaled,
                    self.y_train,
                    self.X_test_scaled,
                    self.y_test,
                )
                # Blended score that weights the test set score 20x more
                # heavily than the training set score.
                score = (training_score + test_score * 20) / 21
                if not self.spec.quiet:
                    print(
                        f"{bold()}{params}{reset()}: "
                        f"Training set score={training_score:.2f}%, "
                        f"test set score={test_score:.2f}%",
                        file=sys.stderr,
                    )
                if best_score is None or score > best_score:
                    best_score = score
                    best_test_score = test_score
                    best_training_score = training_score
                    best_model = model
                    best_params = params
                    if not self.spec.quiet:
                        print(
                            f"New best score {best_score:.2f}% with params {params}"
                        )

        if not self.spec.quiet:
            msg = f"Done training; best test set score was: {best_test_score:.1f}%"
            print(msg)
            logger.info(msg)
        scaler_filename, model_filename, model_info_filename = (
            self.maybe_persist_scaler_and_model(
                best_training_score,
                best_test_score,
                best_params,
                num_examples,
                scaler,
                best_model,
            )
        )
        return OutputSpec(
            model_filename=model_filename,
            model_info_filename=model_info_filename,
            scaler_filename=scaler_filename,
            training_score=best_training_score,
            test_score=best_test_score,
        )

    @par.parallelize(method=par.Method.THREAD)
    def read_files_from_list(
            self,
            files: List[str],
            n: int
    ) -> Tuple[List, List]:
        """Parse one shard of feature files into (features, labels) lists."""

        # All features
        X = []

        # The label
        y = []

        for filename in files:
            wrote_label = False
            with open(filename, "r") as f:
                lines = f.readlines()

            # This example's features
            x = []
            for line in lines:

                # We expect lines in features files to be of the form:
                #
                # key: value
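                #
                # For example (hypothetical feature names; the actual
                # delimiter comes from spec.key_value_delimiter):
                #
                #   brightness: 37
                #   is_weekend: True
                #   label: 1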
                line = line.strip()
                try:
                    (key, value) = line.split(self.spec.key_value_delimiter)
                except Exception as e:
                    logger.debug(
                        f"WARNING: bad line in file {filename} '{line}' ({e}); skipped"
                    )
                    continue

                key = key.strip()
                value = value.strip()
                if (self.spec.features_to_skip is not None
                        and key in self.spec.features_to_skip):
                    logger.debug(f"Skipping feature {key}")
                    continue

                value = self.normalize_feature(value)

                if key == self.spec.label:
                    y.append(value)
                    wrote_label = True
                else:
                    x.append(value)

            # Make sure we saw a label and the requisite number of features.
            if len(x) == self.spec.feature_count and wrote_label:
                X.append(x)
                self.file_done_count += 1
            else:
                if wrote_label:
                    y.pop()

                if self.spec.delete_bad_inputs:
                    msg = f"WARNING: {filename}: missing features or label.  DELETING."
                    print(msg, file=sys.stderr)
                    logger.warning(msg)
                    os.remove(filename)
                else:
                    msg = f"WARNING: {filename}: missing features or label.  Skipped."
                    print(msg, file=sys.stderr)
                    logger.warning(msg)
        return (X, y)

    def make_progress_graph(self) -> None:
        if not self.spec.quiet:
            text_utils.progress_graph(self.file_done_count,
                                      self.total_file_count)

    @decorator_utils.timed
    def read_input_files(self):
        # All features
        X = []

        # The label
        y = []

        results = []
        all_files = glob.glob(self.spec.file_glob)
        self.total_file_count = len(all_files)
        for n, files in enumerate(list_utils.shard(all_files, 500)):
            file_list = list(files)
            results.append(self.read_files_from_list(file_list, n))

        for result in smart_future.wait_any(results, callback=self.make_progress_graph):
            result = result._resolve()
            for z in result[0]:
                X.append(z)
            for z in result[1]:
                y.append(z)
        if not self.spec.quiet:
            print(" " * 80 + "\n")
        return (X, y)

    def normalize_feature(self, value: str) -> Any:
        """Map a raw feature string to a number: "False"/"None" -> 0,
        "True" -> 255, decimal strings -> round(float(value) * 100),
        anything else -> int(value).
        """
        if value in ("False", "None"):
            ret = 0
        elif value == "True":
            ret = 255
        elif isinstance(value, str) and "." in value:
            ret = round(float(value) * 100.0)
        else:
            ret = int(value)
        return ret

    def test_train_split(self, X, y) -> List:
        logger.debug("Performing test/train split")
        return train_test_split(
            X,
            y,
            random_state=random.randrange(0, 1000),
        )

    def scale_data(self,
                   X_train: np.ndarray,
                   X_test: np.ndarray) -> Tuple[Any, np.ndarray, np.ndarray]:
        """Fit a MinMaxScaler on the training data and return it (so it can
        be persisted alongside the model) plus the scaled train/test sets.
        """
        logger.debug("Scaling data")
        scaler = MinMaxScaler()
        scaler.fit(X_train)
        return (scaler, scaler.transform(X_train), scaler.transform(X_test))

    # Note: subclasses must implement this.  Consider decorating the override
    # with @par.parallelize so that candidate models train concurrently; see
    # the example sketch at the bottom of this file.
    @abstractmethod
    def train_model(self,
                    parameters,
                    X_train_scaled: np.ndarray,
                    y_train: np.ndarray) -> Any:
        """Train and return one candidate model for the given parameters."""
        pass

    def evaluate_model(
            self,
            model: Any,
            X_train_scaled: np.ndarray,
            y_train: np.ndarray,
            X_test_scaled: np.ndarray,
            y_test: np.ndarray) -> Tuple[np.float64, np.float64]:
        logger.debug("Evaluating the model")
        training_score = model.score(X_train_scaled, y_train) * 100.0
        test_score = model.score(X_test_scaled, y_test) * 100.0
        logger.info(
            f"Model evaluation results: test_score={test_score:.5f}, "
            f"train_score={training_score:.5f}"
        )
        return (training_score, test_score)

    def maybe_persist_scaler_and_model(
            self,
            training_score: np.float64,
            test_score: np.float64,
            params: str,
            num_examples: int,
            scaler: Any,
            model: Any) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """Pickle the scaler and model (plus a small info file) unless this
        is a dry run; returns the filenames written, or Nones if nothing
        was persisted.
        """
        if not self.spec.dry_run:
            if (
                    (self.spec.persist_percentage_threshold is not None and
                     test_score >= self.spec.persist_percentage_threshold)
                    or
                    (not self.spec.quiet
                     and input_utils.yn_response("Write the model? [y,n]: ") == "y")
            ):
                scaler_filename = f"{self.spec.basename}_scaler.sav"
                with open(scaler_filename, "wb") as f:
                    pickle.dump(scaler, f)
                msg = f"Wrote {scaler_filename}"
                print(msg)
                logger.info(msg)
                model_filename = f"{self.spec.basename}_model.sav"
                with open(model_filename, "wb") as f:
                    pickle.dump(model, f)
                msg = f"Wrote {model_filename}"
                print(msg)
                logger.info(msg)
                model_info_filename = f"{self.spec.basename}_model_info.txt"
                now: datetime.datetime = datetime_utils.now_pst()
                info = f"""Timestamp: {datetime_utils.datetime_to_string(now)}
Model params: {params}
Training examples: {num_examples}
Training set score: {training_score:.2f}%
Testing set score: {test_score:.2f}%"""
                with open(model_info_filename, "w") as f:
                    f.write(info)
                msg = f"Wrote {model_info_filename}:"
                print(msg)
                logger.info(msg)
                print(string_utils.indent(info, 2))
                logger.info(info)
                return (scaler_filename, model_filename, model_info_filename)
        return (None, None, None)
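

# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of this module's API): a minimal concrete
# trainer built on TrainingBlueprint.  The KNeighborsClassifier, the glob, the
# feature count, and the parameter grid below are hypothetical stand-ins.  As
# with read_files_from_list above, train_model is wrapped in @par.parallelize
# so that train() receives SmartFutures; this assumes (as train() appears to)
# that those futures expose get_id().
class ExampleKnnTrainer(TrainingBlueprint):
    @par.parallelize(method=par.Method.THREAD)
    def train_model(self,
                    parameters,
                    X_train_scaled: np.ndarray,
                    y_train: np.ndarray) -> Any:
        # Local import so this optional example does not burden module import.
        from sklearn.neighbors import KNeighborsClassifier  # type: ignore

        model = KNeighborsClassifier(n_neighbors=parameters["n_neighbors"])
        model.fit(X_train_scaled, y_train)
        return model


def _example_main() -> None:
    # Assumes the program's normal config bootstrapping has already populated
    # config.config; all concrete values here are hypothetical.
    spec = InputSpec.populate_from_config()
    spec.file_glob = "/tmp/features/*.dat"
    spec.feature_count = 12
    spec.features_to_skip = {"timestamp"}
    spec.key_value_delimiter = ":"
    spec.training_parameters = [{"n_neighbors": k} for k in (3, 5, 7)]
    spec.label = "label"
    spec.basename = "example_knn"
    output = ExampleKnnTrainer().train(spec)
    print(output)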