From cce8d58af187c0a7fb7585eab5bda9fed731b719 Mon Sep 17 00:00:00 2001 From: Scott Gasch Date: Thu, 10 Feb 2022 13:36:58 -0800 Subject: [PATCH] More cleanup. --- ml/model_trainer.py | 64 +++++++++++++++++++++++---------------- smart_home/lights.py | 34 ++++++++++++--------- smart_home/outlets.py | 1 + smart_home/registry.py | 69 ++++++++++++++++++++++-------------------- 4 files changed, 95 insertions(+), 73 deletions(-) diff --git a/ml/model_trainer.py b/ml/model_trainer.py index 12ccb3c..1509577 100644 --- a/ml/model_trainer.py +++ b/ml/model_trainer.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 -from __future__ import annotations +"""This is a blueprint for training sklearn ML models.""" +from __future__ import annotations import datetime import glob import logging @@ -11,8 +12,9 @@ import random import sys import warnings from abc import ABC, abstractmethod +from dataclasses import dataclass from types import SimpleNamespace -from typing import Any, List, NamedTuple, Optional, Set, Tuple +from typing import Any, List, Optional, Set, Tuple import numpy as np from sklearn.model_selection import train_test_split # type:ignore @@ -25,7 +27,7 @@ import parallelize as par from ansi import bold, reset from decorator_utils import timed -logger = logging.getLogger(__file__) +logger = logging.getLogger(__name__) parser = config.add_commandline_args( f"ML Model Trainer ({__file__})", @@ -56,6 +58,9 @@ group.add_argument( class InputSpec(SimpleNamespace): + """A collection of info needed to train the model provided by the + caller.""" + file_glob: str feature_count: int features_to_skip: Set[str] @@ -78,15 +83,20 @@ class InputSpec(SimpleNamespace): ) -class OutputSpec(NamedTuple): - model_filename: Optional[str] - model_info_filename: Optional[str] - scaler_filename: Optional[str] - training_score: np.float64 - test_score: np.float64 +@dataclass +class OutputSpec: + """Info about the results of training returned to the caller.""" + + model_filename: Optional[str] = None + model_info_filename: Optional[str] = None + scaler_filename: Optional[str] = None + training_score: np.float64 = np.float64(0.0) + test_score: np.float64 = np.float64(0.0) class TrainingBlueprint(ABC): + """The blueprint for doing the actual training.""" + def __init__(self): self.y_train = None self.y_test = None @@ -112,13 +122,13 @@ class TrainingBlueprint(ABC): y = np.array(y_) print("Doing random test/train split...") - X_train, X_test, self.y_train, self.y_test = self.test_train_split( + X_train, X_test, self.y_train, self.y_test = TrainingBlueprint.test_train_split( X, y, ) print("Scaling training data...") - scaler, self.X_train_scaled, self.X_test_scaled = self.scale_data( + scaler, self.X_train_scaled, self.X_test_scaled = TrainingBlueprint.scale_data( X_train, X_test, ) @@ -141,7 +151,7 @@ class TrainingBlueprint(ABC): if isinstance(model, smart_future.SmartFuture): model = model._resolve() if model is not None: - training_score, test_score = self.evaluate_model( + training_score, test_score = TrainingBlueprint.evaluate_model( model, self.X_train_scaled, self.y_train, @@ -195,7 +205,7 @@ class TrainingBlueprint(ABC): ) @par.parallelize(method=par.Method.THREAD) - def read_files_from_list(self, files: List[str], n: int) -> Tuple[List, List]: + def read_files_from_list(self, files: List[str]) -> Tuple[List, List]: # All features X = [] @@ -218,16 +228,16 @@ class TrainingBlueprint(ABC): try: (key, value) = line.split(self.spec.key_value_delimiter) except Exception: - logger.debug(f"WARNING: bad line in file {filename} '{line}', skipped") + logger.debug("WARNING: bad line in file %s '%s', skipped", filename, line) continue key = key.strip() value = value.strip() if self.spec.features_to_skip is not None and key in self.spec.features_to_skip: - logger.debug(f"Skipping feature {key}") + logger.debug("Skipping feature %s", key) continue - value = self.normalize_feature(value) + value = TrainingBlueprint.normalize_feature(value) if key == self.spec.label: y.append(value) @@ -274,9 +284,9 @@ class TrainingBlueprint(ABC): results = [] all_files = glob.glob(self.spec.file_glob) self.total_file_count = len(all_files) - for n, files in enumerate(list_utils.shard(all_files, 500)): + for files in list_utils.shard(all_files, 500): file_list = list(files) - results.append(self.read_files_from_list(file_list, n)) + results.append(self.read_files_from_list(file_list)) for result in smart_future.wait_any(results, callback=self.make_progress_graph): result = result._resolve() @@ -288,7 +298,8 @@ class TrainingBlueprint(ABC): print(" " * 80 + "\n") return (X, y) - def normalize_feature(self, value: str) -> Any: + @staticmethod + def normalize_feature(value: str) -> Any: if value in ("False", "None"): ret = 0 elif value == "True": @@ -299,7 +310,8 @@ class TrainingBlueprint(ABC): ret = int(value) return ret - def test_train_split(self, X, y) -> List: + @staticmethod + def test_train_split(X, y) -> List: logger.debug("Performing test/train split") return train_test_split( X, @@ -307,9 +319,8 @@ class TrainingBlueprint(ABC): random_state=random.randrange(0, 1000), ) - def scale_data( - self, X_train: np.ndarray, X_test: np.ndarray - ) -> Tuple[Any, np.ndarray, np.ndarray]: + @staticmethod + def scale_data(X_train: np.ndarray, X_test: np.ndarray) -> Tuple[Any, np.ndarray, np.ndarray]: logger.debug("Scaling data") scaler = MinMaxScaler() scaler.fit(X_train) @@ -320,8 +331,8 @@ class TrainingBlueprint(ABC): def train_model(self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray) -> Any: pass + @staticmethod def evaluate_model( - self, model: Any, X_train_scaled: np.ndarray, y_train: np.ndarray, @@ -332,8 +343,9 @@ class TrainingBlueprint(ABC): training_score = model.score(X_train_scaled, y_train) * 100.0 test_score = model.score(X_test_scaled, y_test) * 100.0 logger.info( - f"Model evaluation results: test_score={test_score:.5f}, " - f"train_score={training_score:.5f}" + "Model evaluation results: test_score=%.5f, train_score=%.5f", + test_score, + training_score, ) return (training_score, test_score) diff --git a/smart_home/lights.py b/smart_home/lights.py index ac1cc88..908fe8d 100644 --- a/smart_home/lights.py +++ b/smart_home/lights.py @@ -55,11 +55,13 @@ def tplink_light_command(command: str) -> bool: logger.warning(msg) logging_utils.hlog(msg) return False - logger.debug(f'{command} succeeded.') + logger.debug('%s succeeded.', command) return True class BaseLight(dev.Device): + """A base class representing a smart light.""" + def __init__(self, name: str, mac: str, keywords: str = "") -> None: super().__init__(name.strip(), mac.strip(), keywords) @@ -111,8 +113,7 @@ class BaseLight(dev.Device): class GoogleLight(BaseLight): - def __init__(self, name: str, mac: str, keywords: str = "") -> None: - super().__init__(name, mac, keywords) + """A smart light controlled by talking to Google.""" def goog_name(self) -> str: name = self.get_name() @@ -187,6 +188,8 @@ class GoogleLight(BaseLight): class TuyaLight(BaseLight): + """A Tuya smart light.""" + ids_by_mac = { '68:C6:3A:DE:1A:94': '8844664268c63ade1a94', '68:C6:3A:DE:27:1A': '8844664268c63ade271a', @@ -251,14 +254,14 @@ class TuyaLight(BaseLight): @overrides def set_dimmer_level(self, level: int) -> bool: - logger.debug(f'Setting brightness to {level}') + logger.debug('Setting brightness to %d', level) self.bulb.set_brightness(level) return True @overrides def make_color(self, color: str) -> bool: rgb = BaseLight.parse_color_string(color) - logger.debug(f'Light color: {color} -> {rgb}') + logger.debug('Light color: %s -> %s', color, rgb) if rgb is not None: self.bulb.set_colour(rgb[0], rgb[1], rgb[2]) return True @@ -266,16 +269,19 @@ class TuyaLight(BaseLight): class TPLinkLight(BaseLight): + """A TPLink smart light.""" + def __init__(self, name: str, mac: str, keywords: str = "") -> None: super().__init__(name, mac, keywords) self.children: List[str] = [] self.info: Optional[Dict] = None self.info_ts: Optional[datetime.datetime] = None - if "children" in self.keywords: - self.info = self.get_info() - if self.info is not None: - for child in self.info["children"]: - self.children.append(child["id"]) + if self.keywords is not None: + if "children" in self.keywords: + self.info = self.get_info() + if self.info is not None: + for child in self.info["children"]: + self.children.append(child["id"]) @memoized def get_tplink_name(self) -> Optional[str]: @@ -300,7 +306,7 @@ class TPLinkLight(BaseLight): cmd = self.get_cmdline(child) + f"-c {cmd}" if extra_args is not None: cmd += f" {extra_args}" - logger.debug(f'About to execute {cmd}') + logger.debug('About to execute: %s', cmd) return tplink_light_command(cmd) @overrides @@ -329,14 +335,14 @@ class TPLinkLight(BaseLight): @timeout(10.0, use_signals=False, error_message="Timed out waiting for tplink.py") def get_info(self) -> Optional[Dict]: cmd = self.get_cmdline() + "-c info" - logger.debug(f'Getting status of {self.mac} via "{cmd}"...') + logger.debug('Getting status of %s via "%s"...', self.mac, cmd) out = subprocess.getoutput(cmd) - logger.debug(f'RAW OUT> {out}') + logger.debug('RAW OUT> %s', out) out = re.sub("Sent:.*\n", "", out) out = re.sub("Received: *", "", out) try: self.info = json.loads(out)["system"]["get_sysinfo"] - logger.debug(json.dumps(self.info, indent=4, sort_keys=True)) + logger.debug("%s", json.dumps(self.info, indent=4, sort_keys=True)) self.info_ts = datetime.datetime.now() return self.info except Exception as e: diff --git a/smart_home/outlets.py b/smart_home/outlets.py index 500ea05..a7f6f47 100644 --- a/smart_home/outlets.py +++ b/smart_home/outlets.py @@ -155,6 +155,7 @@ class TPLinkOutletWithChildren(TPLinkOutlet): self.children: List[str] = [] self.info: Optional[Dict] = None self.info_ts: Optional[datetime.datetime] = None + assert self.keywords is not None assert "children" in self.keywords self.info = self.get_info() if self.info is not None: diff --git a/smart_home/registry.py b/smart_home/registry.py index 16e18ba..f79ae90 100644 --- a/smart_home/registry.py +++ b/smart_home/registry.py @@ -1,18 +1,17 @@ #!/usr/bin/env python3 +"""A searchable registry of known smart home devices and a factory for +constructing our wrappers around them.""" + import logging import re -from typing import List, Optional, Set +from typing import Dict, List, Optional, Set import argparse_utils import config import file_utils import logical_search -import smart_home.cameras as cameras -import smart_home.chromecasts as chromecasts -import smart_home.device as device -import smart_home.lights as lights -import smart_home.outlets as outlets +from smart_home import cameras, chromecasts, device, lights, outlets args = config.add_commandline_args( f"Smart Home Registry ({__file__})", @@ -27,26 +26,29 @@ args.add_argument( ) -logger = logging.getLogger(__file__) +logger = logging.getLogger(__name__) class SmartHomeRegistry(object): + """A searchable registry of known smart home devices and a factory for + constructing our wrappers around them.""" + def __init__( self, registry_file: Optional[str] = None, filters: List[str] = ['smart'], ) -> None: - self._macs_by_name = {} - self._keywords_by_name = {} - self._keywords_by_mac = {} - self._names_by_mac = {} - self._corpus = logical_search.Corpus() + self._macs_by_name: Dict[str, str] = {} + self._keywords_by_name: Dict[str, str] = {} + self._keywords_by_mac: Dict[str, str] = {} + self._names_by_mac: Dict[str, str] = {} + self._corpus: logical_search.Corpus = logical_search.Corpus() # Read the disk config file... if registry_file is None: registry_file = config.config['smart_home_registry_file_location'] assert file_utils.does_file_exist(registry_file) - logger.debug(f'Reading {registry_file}') + logger.debug('Reading %s', registry_file) with open(registry_file, "r") as rf: contents = rf.readlines() @@ -57,12 +59,11 @@ class SmartHomeRegistry(object): line = line.strip() if line == "": continue - logger.debug(f'SH-CONFIG> {line}') + logger.debug('SH-CONFIG> %s', line) try: (mac, name, keywords) = line.split(",") except ValueError: - msg = f'SH-CONFIG> "{line}" is malformed?! Skipping it.' - logger.warning(msg) + logger.warning('SH-CONFIG> "%s" is malformed?! Skipping it.', line) continue mac = mac.strip() name = name.strip() @@ -72,7 +73,7 @@ class SmartHomeRegistry(object): if filters is not None: for f in filters: if f not in keywords: - logger.debug(f'Skipping this entry b/c of filter {f}') + logger.debug('Skipping this entry b/c of filter: %s', f) skip = True break if not skip: @@ -91,14 +92,14 @@ class SmartHomeRegistry(object): properties.append((key, value)) else: tags.add(kw) - device = logical_search.Document( + dev = logical_search.Document( docid=mac, tags=tags, properties=properties, reference=None, ) - logger.debug(f'Indexing document {device}') - self._corpus.add_doc(device) + logger.debug('Indexing document: %s', dev) + self._corpus.add_doc(dev) def __repr__(self) -> str: s = "Known devices:\n" @@ -107,7 +108,7 @@ class SmartHomeRegistry(object): s += f" {name} ({mac}) => {keywords}\n" return s - def get_keywords_by_name(self, name: str) -> Optional[device.Device]: + def get_keywords_by_name(self, name: str) -> Optional[str]: return self._keywords_by_name.get(name, None) def get_macs_by_name(self, name: str) -> Set[str]: @@ -131,18 +132,18 @@ class SmartHomeRegistry(object): def get_all_devices(self) -> List[device.Device]: retval = [] - for (mac, kws) in self._keywords_by_mac.items(): + for mac, _ in self._keywords_by_mac.items(): if mac is not None: - device = self.get_device_by_mac(mac) - if device is not None: - retval.append(device) + dev = self.get_device_by_mac(mac) + if dev is not None: + retval.append(dev) return retval def get_device_by_mac(self, mac: str) -> Optional[device.Device]: if mac in self._keywords_by_mac: name = self._names_by_mac[mac] kws = self._keywords_by_mac[mac] - logger.debug(f'Found {name} -> {mac} ({kws})') + logger.debug('Found %s -> %s (%s)', name, mac, kws) try: if 'light' in kws.lower(): if 'tplink' in kws.lower(): @@ -184,11 +185,13 @@ class SmartHomeRegistry(object): except Exception as e: logger.exception(e) logger.debug( - f'Device {name} at {mac} with {kws} confused me, returning a generic Device' + 'Device %s at %s with %s confused me; returning a generic Device', + name, + mac, + kws, ) return device.Device(name, mac, kws) - msg = f'{mac} is not a known smart home device, returning None' - logger.warning(msg) + logger.warning('%s is not a known smart home device, returning None', mac) return None def query(self, query: str) -> List[device.Device]: @@ -197,12 +200,12 @@ class SmartHomeRegistry(object): Returns a list of matching lights. """ retval = [] - logger.debug(f'Executing query {query}') + logger.debug('Executing query: %s', query) results = self._corpus.query(query) if results is not None: for mac in results: if mac is not None: - device = self.get_device_by_mac(mac) - if device is not None: - retval.append(device) + dev = self.get_device_by_mac(mac) + if dev is not None: + retval.append(dev) return retval -- 2.45.2