projects
/
python_utils.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Added some helpers to file_utils and improved the docs/doctests.
[python_utils.git]
/
ml
/
model_trainer.py
diff --git
a/ml/model_trainer.py
b/ml/model_trainer.py
index f61b8e745b6cd6c02f23e258003928caba81916b..79ce7062b5b4a05616cddcf7b3d35d59cfbef007 100644
(file)
--- a/
ml/model_trainer.py
+++ b/
ml/model_trainer.py
@@
-12,6
+12,7
@@
import random
import sys
from types import SimpleNamespace
from typing import Any, List, NamedTuple, Optional, Set, Tuple
import sys
from types import SimpleNamespace
from typing import Any, List, NamedTuple, Optional, Set, Tuple
+import warnings
import numpy as np
from sklearn.model_selection import train_test_split # type:ignore
import numpy as np
from sklearn.model_selection import train_test_split # type:ignore
@@
-246,14
+247,14
@@
class TrainingBlueprint(ABC):
y.pop()
if self.spec.delete_bad_inputs:
y.pop()
if self.spec.delete_bad_inputs:
- msg = f"WARNING: {filename}: missing features or label. DELETING."
- print(msg, file=sys.stderr)
+ msg = f"WARNING: {filename}: missing features or label; expected {self.spec.feature_count} but saw {len(x)}. DELETING."
logger.warning(msg)
logger.warning(msg)
+ warnings.warn(msg)
os.remove(filename)
else:
os.remove(filename)
else:
- msg = f"WARNING: {filename}: missing features or label. Skipped."
- print(msg, file=sys.stderr)
+ msg = f"WARNING: {filename}: missing features or label; expected {self.spec.feature_count} but saw {len(x)}. Skipping."
logger.warning(msg)
logger.warning(msg)
+ warnings.warn(msg)
return (X, y)
def make_progress_graph(self) -> None:
return (X, y)
def make_progress_graph(self) -> None: