Make unittest_utils log its perf data to a database, optionally.
[python_utils.git] / ml / model_trainer.py
index f9e132e18aa20ecf2461db55257b6037a0c13a4e..79ce7062b5b4a05616cddcf7b3d35d59cfbef007 100644 (file)
@@ -12,6 +12,7 @@ import random
 import sys
 from types import SimpleNamespace
 from typing import Any, List, NamedTuple, Optional, Set, Tuple
 import sys
 from types import SimpleNamespace
 from typing import Any, List, NamedTuple, Optional, Set, Tuple
+import warnings
 
 import numpy as np
 from sklearn.model_selection import train_test_split  # type:ignore
 
 import numpy as np
 from sklearn.model_selection import train_test_split  # type:ignore
@@ -246,14 +247,14 @@ class TrainingBlueprint(ABC):
                     y.pop()
 
                 if self.spec.delete_bad_inputs:
                     y.pop()
 
                 if self.spec.delete_bad_inputs:
-                    msg = f"WARNING: {filename}: missing features or label.  DELETING."
-                    print(msg, file=sys.stderr)
+                    msg = f"WARNING: {filename}: missing features or label; expected {self.spec.feature_count} but saw {len(x)}.  DELETING."
                     logger.warning(msg)
                     logger.warning(msg)
+                    warnings.warn(msg)
                     os.remove(filename)
                 else:
                     os.remove(filename)
                 else:
-                    msg = f"WARNING: {filename}: missing features or label.  Skipped."
-                    print(msg, file=sys.stderr)
+                    msg = f"WARNING: {filename}: missing features or label; expected {self.spec.feature_count} but saw {len(x)}.  Skipping."
                     logger.warning(msg)
                     logger.warning(msg)
+                    warnings.warn(msg)
         return (X, y)
 
     def make_progress_graph(self) -> None:
         return (X, y)
 
     def make_progress_graph(self) -> None:
@@ -356,6 +357,13 @@ class TrainingBlueprint(ABC):
             import input_utils
             import string_utils
 
             import input_utils
             import string_utils
 
+            now: datetime.datetime = datetime_utils.now_pacific()
+            info = f"""Timestamp: {datetime_utils.datetime_to_string(now)}
+Model params: {params}
+Training examples: {num_examples}
+Training set score: {training_score:.2f}%
+Testing set score: {test_score:.2f}%"""
+            print(f'\n{info}\n')
             if (
                     (self.spec.persist_percentage_threshold is not None and
                      test_score > self.spec.persist_percentage_threshold)
             if (
                     (self.spec.persist_percentage_threshold is not None and
                      test_score > self.spec.persist_percentage_threshold)
@@ -376,12 +384,6 @@ class TrainingBlueprint(ABC):
                 print(msg)
                 logger.info(msg)
                 model_info_filename = f"{self.spec.basename}_model_info.txt"
                 print(msg)
                 logger.info(msg)
                 model_info_filename = f"{self.spec.basename}_model_info.txt"
-                now: datetime.datetime = datetime_utils.now_pacific()
-                info = f"""Timestamp: {datetime_utils.datetime_to_string(now)}
-Model params: {params}
-Training examples: {num_examples}
-Training set score: {training_score:.2f}%
-Testing set score: {test_score:.2f}%"""
                 with open(model_info_filename, "w") as f:
                     f.write(info)
                 msg = f"Wrote {model_info_filename}:"
                 with open(model_info_filename, "w") as f:
                     f.write(info)
                 msg = f"Wrote {model_info_filename}:"