X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=ml%2Fmodel_trainer.py;h=79ce7062b5b4a05616cddcf7b3d35d59cfbef007;hb=2a84ca5a8c75eb7db556b962c645bed79736887b;hp=f61b8e745b6cd6c02f23e258003928caba81916b;hpb=6f688ff9bacee93679f6af45a301b4308e19764c;p=python_utils.git diff --git a/ml/model_trainer.py b/ml/model_trainer.py index f61b8e7..79ce706 100644 --- a/ml/model_trainer.py +++ b/ml/model_trainer.py @@ -12,6 +12,7 @@ import random import sys from types import SimpleNamespace from typing import Any, List, NamedTuple, Optional, Set, Tuple +import warnings import numpy as np from sklearn.model_selection import train_test_split # type:ignore @@ -246,14 +247,14 @@ class TrainingBlueprint(ABC): y.pop() if self.spec.delete_bad_inputs: - msg = f"WARNING: {filename}: missing features or label. DELETING." - print(msg, file=sys.stderr) + msg = f"WARNING: {filename}: missing features or label; expected {self.spec.feature_count} but saw {len(x)}. DELETING." logger.warning(msg) + warnings.warn(msg) os.remove(filename) else: - msg = f"WARNING: {filename}: missing features or label. Skipped." - print(msg, file=sys.stderr) + msg = f"WARNING: {filename}: missing features or label; expected {self.spec.feature_count} but saw {len(x)}. Skipping." logger.warning(msg) + warnings.warn(msg) return (X, y) def make_progress_graph(self) -> None: