Reduce the doctest lease duration...
[python_utils.git] / ml / model_trainer.py
index e3d89c20421619533da6c8fdcddee739ed33ddff..07f7b99292c9c9a3c3ac3a685c40ca59ee1b9582 100644 (file)
@@ -12,6 +12,7 @@ import os
 import pickle
 import random
 import sys
+import time
 import warnings
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
@@ -143,6 +144,7 @@ class TrainingBlueprint(ABC):
             models.append(model)
             modelid_to_params[model.get_id()] = str(params)
 
+        all_models = {}
         best_model = None
         best_score: Optional[np.float64] = None
         best_test_score: Optional[np.float64] = None
@@ -161,6 +163,7 @@ class TrainingBlueprint(ABC):
                     self.y_test,
                 )
                 score = (training_score + test_score * 20) / 21
+                all_models[params] = (score, training_score, test_score)
                 if not self.spec.quiet:
                     print(
                         f"{bold()}{params}{reset()}: "
@@ -177,15 +180,22 @@ class TrainingBlueprint(ABC):
                     if not self.spec.quiet:
                         print(f"New best score {best_score:.2f}% with params {params}")
 
-        if not self.spec.quiet:
-            executors.DefaultExecutors().shutdown()
-            msg = f"Done training; best test set score was: {best_test_score:.1f}%"
-            print(msg)
-            logger.info(msg)
-
+        executors.DefaultExecutors().shutdown()
         assert best_training_score is not None
         assert best_test_score is not None
         assert best_params is not None
+
+        if not self.spec.quiet:
+            time.sleep(1.0)
+            print('Done training...')
+            for params in all_models:
+                msg = f'{bold()}{params}{reset()}: score={all_models[params][0]:.2f}% '
+                msg += f'({all_models[params][2]:.2f}% test, '
+                msg += f'{all_models[params][1]:.2f}% train)'
+                if params == best_params:
+                    msg += f'{bold()} <-- winner{reset()}'
+                print(msg)
+
         (
             scaler_filename,
             model_filename,
@@ -268,9 +278,9 @@ class TrainingBlueprint(ABC):
 
     def make_progress_graph(self) -> None:
         if not self.spec.quiet:
-            from text_utils import progress_graph
+            from text_utils import bar_graph
 
-            progress_graph(self.file_done_count, self.total_file_count)
+            bar_graph(self.file_done_count, self.total_file_count)
 
     @timed
     def read_input_files(self):