From e5d6a97b519e8d074237e93586ba4f7be21e57ce Mon Sep 17 00:00:00 2001
From: Scott Gasch
Date: Sat, 2 Apr 2022 10:50:55 -0700
Subject: [PATCH] Minor changes; make model trainer show results for all models, etc...

---
 base_presence.py    |  2 +-
 ml/model_trainer.py | 22 ++++++++++++++++------
 2 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/base_presence.py b/base_presence.py
index f996d54..5984b41 100755
--- a/base_presence.py
+++ b/base_presence.py
@@ -122,7 +122,7 @@ class PresenceDetection(object):
         try:
             raw = cmd(
                 "ssh scott@meerkat.cabin 'cat /home/scott/cron/persisted_mac_addresses.txt'",
-                timeout_seconds=10.0,
+                timeout_seconds=20.0,
             )
             self.parse_raw_macs_file(raw, Location.CABIN)
         except Exception as e:
diff --git a/ml/model_trainer.py b/ml/model_trainer.py
index e3d89c2..34ded74 100644
--- a/ml/model_trainer.py
+++ b/ml/model_trainer.py
@@ -12,6 +12,7 @@ import os
 import pickle
 import random
 import sys
+import time
 import warnings
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
@@ -143,6 +144,7 @@ class TrainingBlueprint(ABC):
                 models.append(model)
                 modelid_to_params[model.get_id()] = str(params)
 
+        all_models = {}
         best_model = None
         best_score: Optional[np.float64] = None
         best_test_score: Optional[np.float64] = None
@@ -161,6 +163,7 @@ class TrainingBlueprint(ABC):
                     self.y_test,
                 )
                 score = (training_score + test_score * 20) / 21
+                all_models[params] = (score, training_score, test_score)
                 if not self.spec.quiet:
                     print(
                         f"{bold()}{params}{reset()}: "
@@ -177,15 +180,22 @@ class TrainingBlueprint(ABC):
                     if not self.spec.quiet:
                         print(f"New best score {best_score:.2f}% with params {params}")
 
-        if not self.spec.quiet:
-            executors.DefaultExecutors().shutdown()
-            msg = f"Done training; best test set score was: {best_test_score:.1f}%"
-            print(msg)
-            logger.info(msg)
-
+        executors.DefaultExecutors().shutdown()
         assert best_training_score is not None
         assert best_test_score is not None
         assert best_params is not None
+
+        if not self.spec.quiet:
+            time.sleep(1.0)
+            print('Done training...')
+            for params in all_models:
+                msg = f'{bold()}{params}{reset()}: score={all_models[params][0]:.2f}% '
+                msg += f'({all_models[params][2]:.2f}% test, '
+                msg += f'{all_models[params][1]:.2f}% train)'
+                if params == best_params:
+                    msg += f'{bold()} <-- winner{reset()}'
+                print(msg)
+
         (
             scaler_filename,
             model_filename,
-- 
2.47.1
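
The substantive piece of this patch is the blended score the trainer uses to rank candidate models: score = (training_score + test_score * 20) / 21, a weighted average in which test-set accuracy counts twenty times as much as training accuracy, so models that merely memorize the training data rank poorly. The sketch below reproduces that ranking-and-report logic outside the TrainingBlueprint class; it is a minimal standalone approximation, not the module itself. The parameter strings and score values are hypothetical, and bold()/reset() here are local stand-ins for the ANSI helpers that ml/model_trainer.py imports.

def bold() -> str:
    return "\033[1m"  # ANSI bold; stand-in for the real helper

def reset() -> str:
    return "\033[0m"  # ANSI reset; stand-in for the real helper

# Hypothetical params -> (training_score%, test_score%) results.
results = {
    "n_estimators=100": (99.1, 93.4),
    "n_estimators=200": (99.4, 95.0),
}

# Blend the two accuracies, weighting the test score 20:1 so that
# overfit models (high train, low test) rank near the bottom.
all_models = {}
for params, (training_score, test_score) in results.items():
    score = (training_score + test_score * 20) / 21
    all_models[params] = (score, training_score, test_score)

best_params = max(all_models, key=lambda p: all_models[p][0])

# One summary line per candidate, flagging the winner -- the same
# report format the patch prints after training completes.
for params, (score, training_score, test_score) in all_models.items():
    msg = f"{bold()}{params}{reset()}: score={score:.2f}% "
    msg += f"({test_score:.2f}% test, {training_score:.2f}% train)"
    if params == best_params:
        msg += f"{bold()} <-- winner{reset()}"
    print(msg)

Run as-is, this prints one line per candidate in the same format as the patched post-training report, with "<-- winner" appended to the highest-scoring entry.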