ACL uses enums, some more tests, other stuff.
[python_utils.git] / ml_model_trainer.py
index 22735c90c87f585bfe115521786eda97ef887a49..ab3059f855388d06b8077a359897bb07ef5b2bc9 100644
@@ -20,14 +20,8 @@ from sklearn.preprocessing import MinMaxScaler  # type: ignore
 from ansi import bold, reset
 import argparse_utils
 import config
-import datetime_utils
-import decorator_utils
-import input_utils
-import list_utils
+from decorator_utils import timed
 import parallelize as par
-import smart_future
-import string_utils
-import text_utils
 
 logger = logging.getLogger(__file__)
 
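The heavyweight helper modules (datetime_utils, input_utils, list_utils, smart_future, string_utils, text_utils) disappear from the top of the file; the hunks below re-import them inside the functions that actually use them, so merely importing this module stays cheap. A minimal sketch of that lazy-import pattern, using a stdlib module as a stand-in for one of the heavier dependencies:

    import time

    def mean_of_rows(rows):
        # Deferred import: loaded on the first call, then served out of
        # sys.modules, so later calls pay only a dictionary lookup.
        import statistics
        return [statistics.mean(row) for row in rows]

    start = time.perf_counter()
    mean_of_rows([[1, 2, 3]])              # first call triggers the import
    print(f"first call: {time.perf_counter() - start:.6f}s")

    start = time.perf_counter()
    mean_of_rows([[4, 5, 6]])              # module already cached
    print(f"second call: {time.perf_counter() - start:.6f}s")
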
@@ -52,10 +46,10 @@ group.add_argument(
     help="Do not write a new model, just report efficacy.",
 )
 group.add_argument(
-    "--ml_trainer_predicate",
+    "--ml_trainer_persist_threshold",
     type=argparse_utils.valid_percentage,
     metavar='0..100',
-    help="Persist the model if the test set score is >= this predicate.",
+    help="Persist the model if the test set score is >= this threshold.",
 )
 
 
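The option is renamed from --ml_trainer_predicate to --ml_trainer_persist_threshold, which better matches what it is: a 0..100 cutoff, not a predicate. argparse_utils.valid_percentage is not shown in this diff; one plausible shape for such an argparse type callable, offered only as an assumption, would be:

    # Assumption: a hand-rolled equivalent of argparse_utils.valid_percentage,
    # an argparse "type" callable that accepts 0..100 and rejects anything else.
    import argparse

    def valid_percentage(value: str) -> float:
        pct = float(value)                 # ValueError becomes an argparse usage error
        if not 0.0 <= pct <= 100.0:
            raise argparse.ArgumentTypeError(f"{value} is not in 0..100")
        return pct

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ml_trainer_persist_threshold",
        type=valid_percentage,
        metavar="0..100",
        help="Persist the model if the test set score is >= this threshold.",
    )
    args = parser.parse_args(["--ml_trainer_persist_threshold", "97.5"])
    assert args.ml_trainer_persist_threshold == 97.5
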
@@ -69,7 +63,7 @@ class InputSpec(SimpleNamespace):
     basename: str
     dry_run: Optional[bool]
     quiet: Optional[bool]
-    persist_predicate: Optional[float]
+    persist_percentage_threshold: Optional[float]
     delete_bad_inputs: Optional[bool]
 
     @staticmethod
@@ -77,7 +71,7 @@ class InputSpec(SimpleNamespace):
         return InputSpec(
             dry_run = config.config["ml_trainer_dry_run"],
             quiet = config.config["ml_trainer_quiet"],
-            persist_predicate = config.config["ml_trainer_predicate"],
+            persist_percentage_threshold = config.config["ml_trainer_persist_threshold"],
             delete_bad_inputs = config.config["ml_trainer_delete"],
         )
 
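Since InputSpec extends SimpleNamespace, this is a plain field rename and callers just read the new attribute. A small usage sketch with invented values, using SimpleNamespace directly as a stand-in for InputSpec:

    from types import SimpleNamespace

    # Stand-in for InputSpec; the values here are made up for illustration.
    spec = SimpleNamespace(
        dry_run=False,
        quiet=True,
        persist_percentage_threshold=95.0,
        delete_bad_inputs=False,
    )
    if spec.persist_percentage_threshold is not None:
        print(f"persist models whose test score >= {spec.persist_percentage_threshold}%")
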
@@ -101,6 +95,8 @@ class TrainingBlueprint(ABC):
         self.spec = None
 
     def train(self, spec: InputSpec) -> OutputSpec:
+        import smart_future
+
         random.seed()
         self.spec = spec
 
@@ -142,35 +138,36 @@ class TrainingBlueprint(ABC):
         best_test_score = None
         best_training_score = None
         best_params = None
-        for model in smart_future.wait_many(models):
+        for model in smart_future.wait_any(models):
             params = modelid_to_params[model.get_id()]
             if isinstance(model, smart_future.SmartFuture):
                 model = model._resolve()
-            training_score, test_score = self.evaluate_model(
-                model,
-                self.X_train_scaled,
-                self.y_train,
-                self.X_test_scaled,
-                self.y_test,
-            )
-            score = (training_score + test_score * 20) / 21
-            if not self.spec.quiet:
-                print(
-                    f"{bold()}{params}{reset()}: "
-                    f"Training set score={training_score:.2f}%, "
-                    f"test set score={test_score:.2f}%",
-                    file=sys.stderr,
+            if model is not None:
+                training_score, test_score = self.evaluate_model(
+                    model,
+                    self.X_train_scaled,
+                    self.y_train,
+                    self.X_test_scaled,
+                    self.y_test,
                 )
-            if best_score is None or score > best_score:
-                best_score = score
-                best_test_score = test_score
-                best_training_score = training_score
-                best_model = model
-                best_params = params
+                score = (training_score + test_score * 20) / 21
                 if not self.spec.quiet:
                     print(
-                        f"New best score {best_score:.2f}% with params {params}"
+                        f"{bold()}{params}{reset()}: "
+                        f"Training set score={training_score:.2f}%, "
+                        f"test set score={test_score:.2f}%",
+                        file=sys.stderr,
                     )
+                if best_score is None or score > best_score:
+                    best_score = score
+                    best_test_score = test_score
+                    best_training_score = training_score
+                    best_model = model
+                    best_params = params
+                    if not self.spec.quiet:
+                        print(
+                            f"New best score {best_score:.2f}% with params {params}"
+                        )
 
         if not self.spec.quiet:
             msg = f"Done training; best test set score was: {best_test_score:.1f}%"
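Two behavior changes land in this hunk: smart_future.wait_many becomes wait_any, which (judging from the surrounding code) hands back each future as it completes rather than waiting on the whole batch, and a new None guard skips models whose fit produced nothing. A rough standard-library analogue of the selection loop, keeping the same 20:1 test-set weighting; the names here are stand-ins, not the project's API:

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def pick_best(candidates, evaluate_model):
        """candidates: iterable of (params, fit_fn) pairs; fit_fn() -> model or None."""
        best_score = best_model = best_params = None
        with ThreadPoolExecutor() as pool:
            futures = {pool.submit(fit_fn): params for params, fit_fn in candidates}
            for fut in as_completed(futures):    # handled as each fit finishes
                model = fut.result()
                if model is None:                # failed fit: nothing to score
                    continue
                training_score, test_score = evaluate_model(model)
                # Same weighting as above: the test-set score counts 20x.
                score = (training_score + test_score * 20) / 21
                if best_score is None or score > best_score:
                    best_score, best_model, best_params = score, model, futures[fut]
        return best_model, best_params, best_score
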
@@ -222,8 +219,7 @@ class TrainingBlueprint(ABC):
                 try:
                     (key, value) = line.split(self.spec.key_value_delimiter)
                 except Exception as e:
-                    logger.exception(e)
-                    print(f"WARNING: bad line '{line}', skipped")
+                    logger.debug(f"WARNING: bad line in file {filename} '{line}', skipped")
                     continue
 
                 key = key.strip()
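The malformed-line handler now logs the file name at debug level and drops the traceback-plus-print combination. A self-contained sketch of that parsing step, with illustrative names rather than the real read loop:

    import logging

    logger = logging.getLogger(__name__)

    def parse_key_values(lines, delimiter=":", filename="<memory>"):
        """Yield (key, value) pairs, skipping lines that do not split cleanly."""
        for line in lines:
            try:
                key, value = line.split(delimiter)   # ValueError unless exactly two parts
            except ValueError:
                logger.debug("WARNING: bad line in file %s '%s', skipped", filename, line)
                continue
            yield key.strip(), value.strip()

    pairs = list(parse_key_values(["height: 72", "garbage line", "weight: 180"]))
    # -> [("height", "72"), ("weight", "180")]
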
@@ -262,11 +258,17 @@ class TrainingBlueprint(ABC):
 
     def make_progress_graph(self) -> None:
         if not self.spec.quiet:
-            text_utils.progress_graph(self.file_done_count,
-                                      self.total_file_count)
+            from text_utils import progress_graph
+            progress_graph(
+                self.file_done_count,
+                self.total_file_count
+            )
 
-    @decorator_utils.timed
+    @timed
     def read_input_files(self):
+        import list_utils
+        import smart_future
+
         # All features
         X = []
 
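read_input_files keeps its timing decorator, now referenced as the directly imported timed. The decorator's body is not part of this diff; the following is only an assumption about the general shape such a decorator usually takes:

    # Assumed shape of a @timed decorator; decorator_utils.timed itself is not
    # shown in this diff, so treat this as an illustrative stand-in.
    import functools
    import logging
    import time

    logger = logging.getLogger(__name__)

    def timed(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return func(*args, **kwargs)
            finally:
                logger.info("%s took %.3fs", func.__name__, time.perf_counter() - start)
        return wrapper
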
@@ -280,7 +282,7 @@ class TrainingBlueprint(ABC):
             file_list = list(files)
             results.append(self.read_files_from_list(file_list, n))
 
-        for result in smart_future.wait_many(results, callback=self.make_progress_graph):
+        for result in smart_future.wait_any(results, callback=self.make_progress_graph):
             result = result._resolve()
             for z in result[0]:
                 X.append(z)
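Same wait_many to wait_any switch as in train(), here with a callback that appears to be invoked between completions so the progress graph gets redrawn. A rough analogue of that contract, again using the standard library with stand-in names:

    from concurrent.futures import as_completed

    def iter_completed(futures, callback=None):
        # Yield each future as it finishes, invoking the callback (e.g. a
        # progress redraw) once per completion, which is what wait_any is
        # assumed to do with its callback keyword.
        for fut in as_completed(futures):
            if callback is not None:
                callback()
            yield fut
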
@@ -350,9 +352,13 @@ class TrainingBlueprint(ABC):
             scaler: Any,
             model: Any) -> Tuple[Optional[str], Optional[str], Optional[str]]:
         if not self.spec.dry_run:
+            import datetime_utils
+            import input_utils
+            import string_utils
+
             if (
-                    (self.spec.persist_predicate is not None and
-                     test_score > self.spec.persist_predicate)
+                    (self.spec.persist_percentage_threshold is not None and
+                     test_score > self.spec.persist_percentage_threshold)
                     or
                     (not self.spec.quiet
                      and input_utils.yn_response("Write the model? [y,n]: ") == "y")