X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=ml%2Fmodel_trainer.py;fp=ml%2Fmodel_trainer.py;h=79ce7062b5b4a05616cddcf7b3d35d59cfbef007;hb=b454ad295eb3024a238d32bf2aef1ebc3c496b44;hp=acd721868a2a9e04de0da364b8d37dcc268b4fee;hpb=d2478310649d51e14f8ece57651ca9d925d98793;p=python_utils.git diff --git a/ml/model_trainer.py b/ml/model_trainer.py index acd7218..79ce706 100644 --- a/ml/model_trainer.py +++ b/ml/model_trainer.py @@ -12,6 +12,7 @@ import random import sys from types import SimpleNamespace from typing import Any, List, NamedTuple, Optional, Set, Tuple +import warnings import numpy as np from sklearn.model_selection import train_test_split # type:ignore @@ -247,13 +248,13 @@ class TrainingBlueprint(ABC): if self.spec.delete_bad_inputs: msg = f"WARNING: {filename}: missing features or label; expected {self.spec.feature_count} but saw {len(x)}. DELETING." - print(msg, file=sys.stderr) logger.warning(msg) + warnings.warn(msg) os.remove(filename) else: msg = f"WARNING: {filename}: missing features or label; expected {self.spec.feature_count} but saw {len(x)}. Skipping." - print(msg, file=sys.stderr) logger.warning(msg) + warnings.warn(msg) return (X, y) def make_progress_graph(self) -> None: