Improve docstrings for Sphinx.
[python_utils.git] / ml / model_trainer.py
index 15095770cf71f1c7b4fd05e5d3c10a7488775792..34ded741a21131b8f7638ebf475374038c3e6101 100644 (file)
@@ -1,5 +1,7 @@
 #!/usr/bin/env python3
 
+# © Copyright 2021-2022, Scott Gasch
+
 """This is a blueprint for training sklearn ML models."""
 
 from __future__ import annotations
@@ -10,6 +12,7 @@ import os
 import pickle
 import random
 import sys
+import time
 import warnings
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
@@ -141,6 +144,7 @@ class TrainingBlueprint(ABC):
             models.append(model)
             modelid_to_params[model.get_id()] = str(params)
 
+        all_models = {}
         best_model = None
         best_score: Optional[np.float64] = None
         best_test_score: Optional[np.float64] = None
@@ -159,6 +163,7 @@ class TrainingBlueprint(ABC):
                     self.y_test,
                 )
                 score = (training_score + test_score * 20) / 21
+                all_models[params] = (score, training_score, test_score)
                 if not self.spec.quiet:
                     print(
                         f"{bold()}{params}{reset()}: "
@@ -175,15 +180,22 @@ class TrainingBlueprint(ABC):
                     if not self.spec.quiet:
                         print(f"New best score {best_score:.2f}% with params {params}")
 
-        if not self.spec.quiet:
-            executors.DefaultExecutors().shutdown()
-            msg = f"Done training; best test set score was: {best_test_score:.1f}%"
-            print(msg)
-            logger.info(msg)
-
+        executors.DefaultExecutors().shutdown()
         assert best_training_score is not None
         assert best_test_score is not None
         assert best_params is not None
+
+        if not self.spec.quiet:
+            time.sleep(1.0)
+            print('Done training...')
+            for params in all_models:
+                msg = f'{bold()}{params}{reset()}: score={all_models[params][0]:.2f}% '
+                msg += f'({all_models[params][2]:.2f}% test, '
+                msg += f'{all_models[params][1]:.2f}% train)'
+                if params == best_params:
+                    msg += f'{bold()} <-- winner{reset()}'
+                print(msg)
+
         (
             scaler_filename,
             model_filename,