More cleanup.
author: Scott Gasch <[email protected]>
Thu, 10 Feb 2022 21:36:58 +0000 (13:36 -0800)
committer: Scott Gasch <[email protected]>
Thu, 10 Feb 2022 21:36:58 +0000 (13:36 -0800)
ml/model_trainer.py
smart_home/lights.py
smart_home/outlets.py
smart_home/registry.py

index 12ccb3c6c0508e61081d6791b12f9a8f5a5571f2..15095770cf71f1c7b4fd05e5d3c10a7488775792 100644 (file)
@@ -1,7 +1,8 @@
 #!/usr/bin/env python3
 
-from __future__ import annotations
+"""This is a blueprint for training sklearn ML models."""
 
+from __future__ import annotations
 import datetime
 import glob
 import logging
@@ -11,8 +12,9 @@ import random
 import sys
 import warnings
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from types import SimpleNamespace
-from typing import Any, List, NamedTuple, Optional, Set, Tuple
+from typing import Any, List, Optional, Set, Tuple
 
 import numpy as np
 from sklearn.model_selection import train_test_split  # type:ignore
@@ -25,7 +27,7 @@ import parallelize as par
 from ansi import bold, reset
 from decorator_utils import timed
 
-logger = logging.getLogger(__file__)
+logger = logging.getLogger(__name__)
 
 parser = config.add_commandline_args(
     f"ML Model Trainer ({__file__})",
@@ -56,6 +58,9 @@ group.add_argument(
 
 
 class InputSpec(SimpleNamespace):
+    """A collection of info needed to train the model provided by the
+    caller."""
+
     file_glob: str
     feature_count: int
     features_to_skip: Set[str]
@@ -78,15 +83,20 @@ class InputSpec(SimpleNamespace):
         )
 
 
-class OutputSpec(NamedTuple):
-    model_filename: Optional[str]
-    model_info_filename: Optional[str]
-    scaler_filename: Optional[str]
-    training_score: np.float64
-    test_score: np.float64
+@dataclass
+class OutputSpec:
+    """Info about the results of training returned to the caller."""
+
+    model_filename: Optional[str] = None
+    model_info_filename: Optional[str] = None
+    scaler_filename: Optional[str] = None
+    training_score: np.float64 = np.float64(0.0)
+    test_score: np.float64 = np.float64(0.0)
 
 
 class TrainingBlueprint(ABC):
+    """The blueprint for doing the actual training."""
+
     def __init__(self):
         self.y_train = None
         self.y_test = None
@@ -112,13 +122,13 @@ class TrainingBlueprint(ABC):
         y = np.array(y_)
 
         print("Doing random test/train split...")
-        X_train, X_test, self.y_train, self.y_test = self.test_train_split(
+        X_train, X_test, self.y_train, self.y_test = TrainingBlueprint.test_train_split(
             X,
             y,
         )
 
         print("Scaling training data...")
-        scaler, self.X_train_scaled, self.X_test_scaled = self.scale_data(
+        scaler, self.X_train_scaled, self.X_test_scaled = TrainingBlueprint.scale_data(
             X_train,
             X_test,
         )
@@ -141,7 +151,7 @@ class TrainingBlueprint(ABC):
             if isinstance(model, smart_future.SmartFuture):
                 model = model._resolve()
             if model is not None:
-                training_score, test_score = self.evaluate_model(
+                training_score, test_score = TrainingBlueprint.evaluate_model(
                     model,
                     self.X_train_scaled,
                     self.y_train,
@@ -195,7 +205,7 @@ class TrainingBlueprint(ABC):
         )
 
     @par.parallelize(method=par.Method.THREAD)
-    def read_files_from_list(self, files: List[str], n: int) -> Tuple[List, List]:
+    def read_files_from_list(self, files: List[str]) -> Tuple[List, List]:
         # All features
         X = []
 
@@ -218,16 +228,16 @@ class TrainingBlueprint(ABC):
                 try:
                     (key, value) = line.split(self.spec.key_value_delimiter)
                 except Exception:
-                    logger.debug(f"WARNING: bad line in file {filename} '{line}', skipped")
+                    logger.debug("WARNING: bad line in file %s '%s', skipped", filename, line)
                     continue
 
                 key = key.strip()
                 value = value.strip()
                 if self.spec.features_to_skip is not None and key in self.spec.features_to_skip:
-                    logger.debug(f"Skipping feature {key}")
+                    logger.debug("Skipping feature %s", key)
                     continue
 
-                value = self.normalize_feature(value)
+                value = TrainingBlueprint.normalize_feature(value)
 
                 if key == self.spec.label:
                     y.append(value)
@@ -274,9 +284,9 @@ class TrainingBlueprint(ABC):
         results = []
         all_files = glob.glob(self.spec.file_glob)
         self.total_file_count = len(all_files)
-        for n, files in enumerate(list_utils.shard(all_files, 500)):
+        for files in list_utils.shard(all_files, 500):
             file_list = list(files)
-            results.append(self.read_files_from_list(file_list, n))
+            results.append(self.read_files_from_list(file_list))
 
         for result in smart_future.wait_any(results, callback=self.make_progress_graph):
             result = result._resolve()
@@ -288,7 +298,8 @@ class TrainingBlueprint(ABC):
             print(" " * 80 + "\n")
         return (X, y)
 
-    def normalize_feature(self, value: str) -> Any:
+    @staticmethod
+    def normalize_feature(value: str) -> Any:
         if value in ("False", "None"):
             ret = 0
         elif value == "True":
@@ -299,7 +310,8 @@ class TrainingBlueprint(ABC):
             ret = int(value)
         return ret
 
-    def test_train_split(self, X, y) -> List:
+    @staticmethod
+    def test_train_split(X, y) -> List:
         logger.debug("Performing test/train split")
         return train_test_split(
             X,
@@ -307,9 +319,8 @@ class TrainingBlueprint(ABC):
             random_state=random.randrange(0, 1000),
         )
 
-    def scale_data(
-        self, X_train: np.ndarray, X_test: np.ndarray
-    ) -> Tuple[Any, np.ndarray, np.ndarray]:
+    @staticmethod
+    def scale_data(X_train: np.ndarray, X_test: np.ndarray) -> Tuple[Any, np.ndarray, np.ndarray]:
         logger.debug("Scaling data")
         scaler = MinMaxScaler()
         scaler.fit(X_train)
@@ -320,8 +331,8 @@ class TrainingBlueprint(ABC):
     def train_model(self, parameters, X_train_scaled: np.ndarray, y_train: np.ndarray) -> Any:
         pass
 
+    @staticmethod
     def evaluate_model(
-        self,
         model: Any,
         X_train_scaled: np.ndarray,
         y_train: np.ndarray,
@@ -332,8 +343,9 @@ class TrainingBlueprint(ABC):
         training_score = model.score(X_train_scaled, y_train) * 100.0
         test_score = model.score(X_test_scaled, y_test) * 100.0
         logger.info(
-            f"Model evaluation results: test_score={test_score:.5f}, "
-            f"train_score={training_score:.5f}"
+            "Model evaluation results: test_score=%.5f, train_score=%.5f",
+            test_score,
+            training_score,
         )
         return (training_score, test_score)
 
index ac1cc885f94973a74f3b7dabcb41f45a9dbac957..908fe8d3ed3550c809af1678fa905b1311336728 100644 (file)
@@ -55,11 +55,13 @@ def tplink_light_command(command: str) -> bool:
             logger.warning(msg)
             logging_utils.hlog(msg)
             return False
-    logger.debug(f'{command} succeeded.')
+    logger.debug('%s succeeded.', command)
     return True
 
 
 class BaseLight(dev.Device):
+    """A base class representing a smart light."""
+
     def __init__(self, name: str, mac: str, keywords: str = "") -> None:
         super().__init__(name.strip(), mac.strip(), keywords)
 
@@ -111,8 +113,7 @@ class BaseLight(dev.Device):
 
 
 class GoogleLight(BaseLight):
-    def __init__(self, name: str, mac: str, keywords: str = "") -> None:
-        super().__init__(name, mac, keywords)
+    """A smart light controlled by talking to Google."""
 
     def goog_name(self) -> str:
         name = self.get_name()
@@ -187,6 +188,8 @@ class GoogleLight(BaseLight):
 
 
 class TuyaLight(BaseLight):
+    """A Tuya smart light."""
+
     ids_by_mac = {
         '68:C6:3A:DE:1A:94': '8844664268c63ade1a94',
         '68:C6:3A:DE:27:1A': '8844664268c63ade271a',
@@ -251,14 +254,14 @@ class TuyaLight(BaseLight):
 
     @overrides
     def set_dimmer_level(self, level: int) -> bool:
-        logger.debug(f'Setting brightness to {level}')
+        logger.debug('Setting brightness to %d', level)
         self.bulb.set_brightness(level)
         return True
 
     @overrides
     def make_color(self, color: str) -> bool:
         rgb = BaseLight.parse_color_string(color)
-        logger.debug(f'Light color: {color} -> {rgb}')
+        logger.debug('Light color: %s -> %s', color, rgb)
         if rgb is not None:
             self.bulb.set_colour(rgb[0], rgb[1], rgb[2])
             return True
@@ -266,16 +269,19 @@ class TuyaLight(BaseLight):
 
 
 class TPLinkLight(BaseLight):
+    """A TPLink smart light."""
+
     def __init__(self, name: str, mac: str, keywords: str = "") -> None:
         super().__init__(name, mac, keywords)
         self.children: List[str] = []
         self.info: Optional[Dict] = None
         self.info_ts: Optional[datetime.datetime] = None
-        if "children" in self.keywords:
-            self.info = self.get_info()
-            if self.info is not None:
-                for child in self.info["children"]:
-                    self.children.append(child["id"])
+        if self.keywords is not None:
+            if "children" in self.keywords:
+                self.info = self.get_info()
+                if self.info is not None:
+                    for child in self.info["children"]:
+                        self.children.append(child["id"])
 
     @memoized
     def get_tplink_name(self) -> Optional[str]:
@@ -300,7 +306,7 @@ class TPLinkLight(BaseLight):
         cmd = self.get_cmdline(child) + f"-c {cmd}"
         if extra_args is not None:
             cmd += f" {extra_args}"
-        logger.debug(f'About to execute {cmd}')
+        logger.debug('About to execute: %s', cmd)
         return tplink_light_command(cmd)
 
     @overrides
@@ -329,14 +335,14 @@ class TPLinkLight(BaseLight):
     @timeout(10.0, use_signals=False, error_message="Timed out waiting for tplink.py")
     def get_info(self) -> Optional[Dict]:
         cmd = self.get_cmdline() + "-c info"
-        logger.debug(f'Getting status of {self.mac} via "{cmd}"...')
+        logger.debug('Getting status of %s via "%s"...', self.mac, cmd)
         out = subprocess.getoutput(cmd)
-        logger.debug(f'RAW OUT> {out}')
+        logger.debug('RAW OUT> %s', out)
         out = re.sub("Sent:.*\n", "", out)
         out = re.sub("Received: *", "", out)
         try:
             self.info = json.loads(out)["system"]["get_sysinfo"]
-            logger.debug(json.dumps(self.info, indent=4, sort_keys=True))
+            logger.debug("%s", json.dumps(self.info, indent=4, sort_keys=True))
             self.info_ts = datetime.datetime.now()
             return self.info
         except Exception as e:
index 500ea05372dd200444ba3268427b1d6f814850c9..a7f6f47e4225610c21b09ce7e36f547ba7f39157 100644 (file)
@@ -155,6 +155,7 @@ class TPLinkOutletWithChildren(TPLinkOutlet):
         self.children: List[str] = []
         self.info: Optional[Dict] = None
         self.info_ts: Optional[datetime.datetime] = None
+        assert self.keywords is not None
         assert "children" in self.keywords
         self.info = self.get_info()
         if self.info is not None:
index 16e18ba11bcc5fa19443245f02df546c68d54787..f79ae90c15f5f7c881ae1acecaee434681aaf2aa 100644 (file)
@@ -1,18 +1,17 @@
 #!/usr/bin/env python3
 
+"""A searchable registry of known smart home devices and a factory for
+constructing our wrappers around them."""
+
 import logging
 import re
-from typing import List, Optional, Set
+from typing import Dict, List, Optional, Set
 
 import argparse_utils
 import config
 import file_utils
 import logical_search
-import smart_home.cameras as cameras
-import smart_home.chromecasts as chromecasts
-import smart_home.device as device
-import smart_home.lights as lights
-import smart_home.outlets as outlets
+from smart_home import cameras, chromecasts, device, lights, outlets
 
 args = config.add_commandline_args(
     f"Smart Home Registry ({__file__})",
@@ -27,26 +26,29 @@ args.add_argument(
 )
 
 
-logger = logging.getLogger(__file__)
+logger = logging.getLogger(__name__)
 
 
 class SmartHomeRegistry(object):
+    """A searchable registry of known smart home devices and a factory for
+    constructing our wrappers around them."""
+
     def __init__(
         self,
         registry_file: Optional[str] = None,
         filters: List[str] = ['smart'],
     ) -> None:
-        self._macs_by_name = {}
-        self._keywords_by_name = {}
-        self._keywords_by_mac = {}
-        self._names_by_mac = {}
-        self._corpus = logical_search.Corpus()
+        self._macs_by_name: Dict[str, str] = {}
+        self._keywords_by_name: Dict[str, str] = {}
+        self._keywords_by_mac: Dict[str, str] = {}
+        self._names_by_mac: Dict[str, str] = {}
+        self._corpus: logical_search.Corpus = logical_search.Corpus()
 
         # Read the disk config file...
         if registry_file is None:
             registry_file = config.config['smart_home_registry_file_location']
         assert file_utils.does_file_exist(registry_file)
-        logger.debug(f'Reading {registry_file}')
+        logger.debug('Reading %s', registry_file)
         with open(registry_file, "r") as rf:
             contents = rf.readlines()
 
@@ -57,12 +59,11 @@ class SmartHomeRegistry(object):
             line = line.strip()
             if line == "":
                 continue
-            logger.debug(f'SH-CONFIG> {line}')
+            logger.debug('SH-CONFIG> %s', line)
             try:
                 (mac, name, keywords) = line.split(",")
             except ValueError:
-                msg = f'SH-CONFIG> "{line}" is malformed?!  Skipping it.'
-                logger.warning(msg)
+                logger.warning('SH-CONFIG> "%s" is malformed?!  Skipping it.', line)
                 continue
             mac = mac.strip()
             name = name.strip()
@@ -72,7 +73,7 @@ class SmartHomeRegistry(object):
             if filters is not None:
                 for f in filters:
                     if f not in keywords:
-                        logger.debug(f'Skipping this entry b/c of filter {f}')
+                        logger.debug('Skipping this entry b/c of filter: %s', f)
                         skip = True
                         break
             if not skip:
@@ -91,14 +92,14 @@ class SmartHomeRegistry(object):
                 properties.append((key, value))
             else:
                 tags.add(kw)
-        device = logical_search.Document(
+        dev = logical_search.Document(
             docid=mac,
             tags=tags,
             properties=properties,
             reference=None,
         )
-        logger.debug(f'Indexing document {device}')
-        self._corpus.add_doc(device)
+        logger.debug('Indexing document: %s', dev)
+        self._corpus.add_doc(dev)
 
     def __repr__(self) -> str:
         s = "Known devices:\n"
@@ -107,7 +108,7 @@ class SmartHomeRegistry(object):
             s += f"  {name} ({mac}) => {keywords}\n"
         return s
 
-    def get_keywords_by_name(self, name: str) -> Optional[device.Device]:
+    def get_keywords_by_name(self, name: str) -> Optional[str]:
         return self._keywords_by_name.get(name, None)
 
     def get_macs_by_name(self, name: str) -> Set[str]:
@@ -131,18 +132,18 @@ class SmartHomeRegistry(object):
 
     def get_all_devices(self) -> List[device.Device]:
         retval = []
-        for (mac, kws) in self._keywords_by_mac.items():
+        for mac, _ in self._keywords_by_mac.items():
             if mac is not None:
-                device = self.get_device_by_mac(mac)
-                if device is not None:
-                    retval.append(device)
+                dev = self.get_device_by_mac(mac)
+                if dev is not None:
+                    retval.append(dev)
         return retval
 
     def get_device_by_mac(self, mac: str) -> Optional[device.Device]:
         if mac in self._keywords_by_mac:
             name = self._names_by_mac[mac]
             kws = self._keywords_by_mac[mac]
-            logger.debug(f'Found {name} -> {mac} ({kws})')
+            logger.debug('Found %s -> %s (%s)', name, mac, kws)
             try:
                 if 'light' in kws.lower():
                     if 'tplink' in kws.lower():
@@ -184,11 +185,13 @@ class SmartHomeRegistry(object):
             except Exception as e:
                 logger.exception(e)
                 logger.debug(
-                    f'Device {name} at {mac} with {kws} confused me, returning a generic Device'
+                    'Device %s at %s with %s confused me; returning a generic Device',
+                    name,
+                    mac,
+                    kws,
                 )
                 return device.Device(name, mac, kws)
-        msg = f'{mac} is not a known smart home device, returning None'
-        logger.warning(msg)
+        logger.warning('%s is not a known smart home device, returning None', mac)
         return None
 
     def query(self, query: str) -> List[device.Device]:
@@ -197,12 +200,12 @@ class SmartHomeRegistry(object):
         Returns a list of matching lights.
         """
         retval = []
-        logger.debug(f'Executing query {query}')
+        logger.debug('Executing query: %s', query)
         results = self._corpus.query(query)
         if results is not None:
             for mac in results:
                 if mac is not None:
-                    device = self.get_device_by_mac(mac)
-                    if device is not None:
-                        retval.append(device)
+                    dev = self.get_device_by_mac(mac)
+                    if dev is not None:
+                        retval.append(dev)
         return retval