projects
/
python_utils.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Make it clear in arg message that this is a library. Dependency
[python_utils.git]
/
ml
/
model_trainer.py
diff --git
a/ml/model_trainer.py
b/ml/model_trainer.py
index a37885ce3e92f20d11bb61f4268407689829bca4..12ccb3c6c0508e61081d6791b12f9a8f5a5571f2 100644
(file)
--- a/
ml/model_trainer.py
+++ b/
ml/model_trainer.py
@@
-28,7
+28,8
@@
from decorator_utils import timed
logger = logging.getLogger(__file__)
parser = config.add_commandline_args(
logger = logging.getLogger(__file__)
parser = config.add_commandline_args(
- f"ML Model Trainer ({__file__})", "Arguments related to training an ML model"
+ f"ML Model Trainer ({__file__})",
+ "Arguments related to training an ML model",
)
parser.add_argument(
"--ml_trainer_quiet",
)
parser.add_argument(
"--ml_trainer_quiet",
@@
-217,17
+218,12
@@
class TrainingBlueprint(ABC):
try:
(key, value) = line.split(self.spec.key_value_delimiter)
except Exception:
try:
(key, value) = line.split(self.spec.key_value_delimiter)
except Exception:
- logger.debug(
- f"WARNING: bad line in file {filename} '{line}', skipped"
- )
+ logger.debug(f"WARNING: bad line in file {filename} '{line}', skipped")
continue
key = key.strip()
value = value.strip()
continue
key = key.strip()
value = value.strip()
- if (
- self.spec.features_to_skip is not None
- and key in self.spec.features_to_skip
- ):
+ if self.spec.features_to_skip is not None and key in self.spec.features_to_skip:
logger.debug(f"Skipping feature {key}")
continue
logger.debug(f"Skipping feature {key}")
continue
@@
-321,9
+317,7
@@
class TrainingBlueprint(ABC):
# Note: children should implement. Consider using @parallelize.
@abstractmethod
# Note: children should implement. Consider using @parallelize.
@abstractmethod
- def train_model(
- self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray
- ) -> Any:
+ def train_model(self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray) -> Any:
pass
def evaluate_model(
pass
def evaluate_model(
@@
-368,8
+362,7
@@
Testing set score: {test_score:.2f}%"""
self.spec.persist_percentage_threshold is not None
and test_score > self.spec.persist_percentage_threshold
) or (
self.spec.persist_percentage_threshold is not None
and test_score > self.spec.persist_percentage_threshold
) or (
- not self.spec.quiet
- and input_utils.yn_response("Write the model? [y,n]: ") == "y"
+ not self.spec.quiet and input_utils.yn_response("Write the model? [y,n]: ") == "y"
):
scaler_filename = f"{self.spec.basename}_scaler.sav"
with open(scaler_filename, "wb") as fb:
):
scaler_filename = f"{self.spec.basename}_scaler.sav"
with open(scaler_filename, "wb") as fb: